Stanford XCS224U: NLU | Fantastic Language Models and How to Build Them, Part 1 | Spring 2023
Chapters
0:00
0:57 Addressing the known limitations with BERT
2:00 Core model structure (Clark et al. 2019)
5:49 Generator/Discriminator relationships
8:50 ELECTRA efficiency analyses
12:22 ELECTRA model releases
14:54 From the RNN era
15:45 Transformer-based options
21:09 Trends in model size
22:01 Distillation objectives
24:24 Distillation performance
28:16 Pretraining data
28:35 Current trends
Let's get started. We have another action-packed day for you. 00:00:19.120 |
We'll finish up the, uh, slide deck on contextual word representations. 00:00:22.020 |
There are just a few more small things to cover. 00:00:25.320 |
And then Sid is gonna help us get hands-on. You can grab the materials 00:00:36.080 |
from the website if you wanna follow along and we're gonna 00:00:41.360 |
Electra is a model that came from Stanford from Kevin Clark and collaborators. 00:00:48.520 |
It shows you the kind of design space we're in, 00:00:54.160 |
doing something that was different from what had come before in the space of transformers. 00:00:58.640 |
Last time we talked about some known limitations of the BERT model. 00:01:03.120 |
Most of them identified in the BERT paper itself. 00:01:07.520 |
You know, we- we just wanted more ablation studies, 00:01:16.080 |
With Electra, we're gonna address known limitations two and three. 00:01:23.880 |
The first is that there's a mismatch between the pre-training vocabulary and the fine-tuning vocabulary 00:01:27.240 |
because of the role of the mask token in training BERT models. 00:01:31.600 |
And the second one which might feel more pressing to you, 00:01:35.040 |
is that BERT is pretty inefficient when it comes to learning from data 00:01:39.000 |
because we mask out or replace about 15% of the tokens. 00:01:43.720 |
And as you recall from the BERT learning objective, 00:01:46.840 |
those are the only tokens that contribute to the learning objective itself. 00:01:55.240 |
So we might hope to make more efficient use of all these sequences that we're processing. 00:01:58.640 |
Electra is gonna make some progress on that too. 00:02:04.560 |
Let's walk through the core model structure first, and then we'll look at all the other things they did in the paper. 00:02:12.200 |
And the first thing we do is mask out some of those tokens and that could be 00:02:16.080 |
a random sample of 15% of the tokens just like in most work with BERT. 00:02:21.320 |
Then we have what could be literally a BERT model. 00:02:25.880 |
Typically a small one that has a masked language modeling objective. 00:02:29.760 |
And that can produce output sequences as usual. 00:02:33.200 |
However, the twist here is that we're gonna replace some of 00:02:37.880 |
the tokens that came from the input with kind of randomly sampled ones from the MLM. 00:02:43.480 |
You can see here that we've copied over 'the' and copied over 'chef', 00:02:49.980 |
That might not have been the most probable output for the model, 00:02:52.880 |
but we're gonna do that replacement step there. 00:02:55.520 |
So what we've created here is a sequence that we can call X-corrupt, 00:03:02.720 |
And that is the primary job of this generator model. 00:03:06.360 |
At this point, the heart of Electra takes over. 00:03:11.080 |
That's the discriminator, but we can also talk about it as the Electra model itself in essence. 00:03:14.880 |
The job of the discriminator is to figure out which of 00:03:18.780 |
those tokens were originals and which ones were replacements. 00:03:22.800 |
So that's a kind of contrastive learning objective. 00:03:25.360 |
You can see here that the actual label it's gonna learn for 'ate' is that it was 00:03:29.560 |
replaced, and for 'the' that it was original even though it was a sampled token. 00:03:34.760 |
And the actual loss for the model is the generator loss, 00:03:37.800 |
that is, the typical BERT MLM loss, together with this Electra loss with a weighting. 00:03:45.860 |
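To make that concrete, here is a minimal PyTorch-style sketch of one joint pre-training step, assuming hypothetical `generator` and `discriminator` callables that return logits; the weighting on the discriminator term is shown as a parameter (it is on the order of 50 in the paper). This is an illustration of the objective as described above, not the authors' implementation.

```python
import torch
import torch.nn.functional as F

def electra_step(tokens, generator, discriminator, mask_id,
                 mask_prob=0.15, disc_weight=50.0):
    """One joint step: generator MLM loss plus weighted replaced-token detection."""
    # 1. Mask out ~15% of the input tokens, as in BERT.
    mask = torch.rand(tokens.shape, device=tokens.device) < mask_prob
    masked_tokens = tokens.masked_fill(mask, mask_id)

    # 2. The generator (a small masked LM) predicts the masked positions.
    gen_logits = generator(masked_tokens)                 # (batch, seq, vocab)
    mlm_loss = F.cross_entropy(gen_logits[mask], tokens[mask])

    # 3. Sample replacements from the generator to build X-corrupt.
    sampled = torch.distributions.Categorical(logits=gen_logits).sample()
    x_corrupt = torch.where(mask, sampled, tokens)

    # 4. A position counts as "replaced" only if the sampled token differs from
    #    the original; a lucky sample of the original token counts as original.
    is_replaced = (x_corrupt != tokens).float()

    # 5. The discriminator scores every position: replaced or original?
    disc_logits = discriminator(x_corrupt)                # (batch, seq)
    disc_loss = F.binary_cross_entropy_with_logits(disc_logits, is_replaced)

    # 6. Joint loss: MLM loss plus the weighted discriminator loss.
    return mlm_loss + disc_weight * disc_loss
```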
but there is as I said an asymmetry here in the sense that once we've done 00:03:49.640 |
the pre-training phase we can let the generator fall away entirely and focus just on 00:03:55.360 |
the discriminator as the model that we're gonna use for downstream fine-tuning tasks. 00:04:00.680 |
And so you can see already that we've in a way 00:04:04.040 |
solved the problem of having this weird mask token that comes from 00:04:07.440 |
the pre-training phase because the discriminator never sees mask tokens. 00:04:12.000 |
All it sees are these corrupted inputs and it learns to figure out which ones are 00:04:16.960 |
the corrupted versions and which ones are the originals. 00:04:20.740 |
Which is a different capability intuitively than the one we were imbuing the core BERT model with. 00:04:27.960 |
Right? So for BERT it's kind of like the objective is to figure 00:04:30.720 |
out what was missing from the surrounding context. 00:04:33.720 |
And here it's like trying to figure out which of the words in 00:04:37.120 |
the sequence doesn't belong and which of them do belong. 00:04:46.760 |
Before we dive into the experiments and stuff, any questions? 00:04:46.760 |
Yes, I'm wondering what the uses of this model are. 00:04:56.880 |
So it tries to predict which ones have been replaced. 00:05:00.040 |
Like do you- like what applications do you use Electra for? 00:05:03.600 |
For pre-training. Yeah, that's what you got to get your head around. 00:05:08.360 |
So the discriminator is now gonna be our pre-trained artifact. 00:05:14.000 |
Normally you'd download BERT, and when you do that you're downloading some MLM-trained artifact; 00:05:17.320 |
now you download Electra, which is the discriminator. 00:05:21.360 |
And it's been trained to do this distinguishing thing, 00:05:27.600 |
which is a different thing from the continuation behavior of the models we've seen so far. 00:05:30.240 |
The eye-opening thing is that that contrastive objective 00:05:33.500 |
leads to a really good pre-trained state for fine-tuning. 00:05:37.680 |
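In practice that means you grab the released discriminator and fine-tune it like any other encoder. A rough sketch with the Hugging Face transformers library, using one of the public google/electra-*-discriminator checkpoints as an example:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The pretrained artifact you fine-tune is the discriminator, not the generator.
name = "google/electra-base-discriminator"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=2)

inputs = tokenizer("the chef cooked the meal", return_tensors="pt")
logits = model(**inputs).logits   # fine-tune these on your downstream task
```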
And we might hope that it's doing it much more efficiently, 00:05:43.840 |
So first, generator-discriminator relationships. 00:05:47.760 |
They observe in the paper that when the generator and the discriminator are the same size, 00:05:52.720 |
they can share all their transformer parameters, 00:05:58.900 |
and that's kind of intriguing, that one and the same set of weights would be playing 00:06:02.560 |
the role of both the MLM generator and the discriminator. 00:06:09.640 |
But they get the best results from having a generator that is small compared to the discriminator. 00:06:17.360 |
So we've got our GLUE score along the y-axis. 00:06:19.880 |
This will just be a measure of system quality for them. 00:06:27.080 |
Along the x-axis we have generator size, and what they mean by that is the dimensionality of the model in the BERT sense. 00:06:32.040 |
So essentially the size of each one of the layers. 00:06:36.960 |
on this blue line, the best performing model, 00:06:38.840 |
this is where we have 768 as our dimensionality for the discriminator, and the peak is at a much smaller generator size. 00:06:54.640 |
And that's what we mean when we say better to have a small generator and a large discriminator. 00:07:00.000 |
And that kind of U-shaped pattern is repeated across all the different discriminator sizes. 00:07:05.800 |
And that's probably an insight about how this model is working, 00:07:09.080 |
which is the sense that you kind of want the generator to be a little bit of 00:07:12.560 |
a noisy process so that the discriminator has some interesting work to do. 00:07:17.760 |
And by making the discriminator more powerful, 00:07:20.480 |
I guess you're creating that kind of opportunity. 00:07:24.000 |
They also do a lot of work looking at efficiency because one of 00:07:29.400 |
the side goals of the Electra paper was to end up with models that were 00:07:33.560 |
overall more efficient in terms of the pre-training compute and in terms of the model size. 00:07:42.920 |
and along the x-axis now we have pre-training FLOPs, 00:07:45.800 |
which you could just think of as a very low level measure of how 00:07:49.040 |
much compute resources we need to do the pre-training part. 00:07:55.960 |
Electra is the top line, and it's the best no matter what your computational budget is along the x-axis. 00:08:05.960 |
That is a slightly different objective where the generator is trying to fool 00:08:10.560 |
the discriminator by creating corrupted sequences that are hard for the discriminator to distinguish. 00:08:18.400 |
but it's less good than the kind of more cooperative, 00:08:21.120 |
um, joint objective that I showed you before. 00:08:25.840 |
So the green line is where I start training with BERT, 00:08:29.760 |
and that at a certain point I switch to having also the discriminator loss. 00:08:34.640 |
And at that point, the BERT model is less good for any compute budget, 00:08:38.640 |
whereas the Electra variant starts to do its Electra thing and get better and better. 00:08:46.200 |
So these analyses are all pointing to it being a good and efficient model. 00:08:50.040 |
And then finally, they do a bunch more efficiency analyses. 00:08:53.760 |
So this is that picture that I showed you of the full Electra model before, 00:08:57.160 |
where I have the generator creating corrupted sequences, 00:09:00.480 |
and then the discriminator doing its discriminating part there. 00:09:07.400 |
One ablation is Electra 15%, and this is different from full Electra in the sense that on the right, for default Electra, 00:09:12.760 |
we make predictions about all of these tokens, 00:09:19.360 |
whereas here we kind of do a BERT-like thing where we're going to assume that the ones that weren't 00:09:22.720 |
part of those corrupted changes there, the sampled part, don't contribute to the objective. 00:09:31.040 |
Replace MLM. This is an ablation where actually we drop away the Electra part, 00:09:38.800 |
and we're going to not have the mask token at all. 00:09:43.760 |
There are a few ways that they do this learning, 00:09:46.960 |
and they also do the one where they just replace it with some random tokens here, 00:09:50.620 |
like cook to run, and then the model has to predict the token that was originally there. 00:10:00.080 |
This is a kind of look at what happens if we don't introduce 00:10:04.720 |
that mask token addressing that question about whether that was disrupting learning. 00:10:10.840 |
This is again just a BERT-based objective over here where instead of turning off 00:10:14.920 |
the objective for these ones here that weren't part of the corrupted sequence, 00:10:23.040 |
we keep it turned on for all of them, to see if we were making more efficient use of the data. 00:10:27.200 |
Here are the results. So Electra is at the top, 00:10:34.120 |
Then comes All-Tokens MLM; that's just BERT learning from all of the tokens, 00:10:36.400 |
and I think that does show that BERT could have been a little better if they had not 00:10:40.640 |
turned off the objective for every single token that wasn't 00:10:43.600 |
part of the masking or corruption for that learning process. 00:10:49.520 |
Then Replace MLM, and that's where we don't have any mask token. 00:10:49.520 |
So there's no pre-training/fine-tuning mismatch. 00:10:56.920 |
So overall, you're seeing these ablations are showing us that every piece of 00:11:01.240 |
Electra is contributing something to the overall performance of the model, 00:11:14.640 |
Well, we're making more efficient use of the data because 00:11:17.480 |
we're getting a learning signal from every token, 00:11:20.000 |
and I guess that would be the important dimension because a funny thing about BERT where 00:11:24.220 |
we turn off the learning for the ones that weren't masked or corrupted, 00:11:28.980 |
is that we still have to do the work of computing them. 00:11:31.420 |
It's just that then they don't become part of the objective, 00:11:34.360 |
and here we're just kind of bringing that in. 00:11:39.420 |
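In code, that "turning off" is usually just label masking: the model still computes logits for every position, but non-masked positions are excluded from the loss. A minimal sketch using PyTorch's standard ignore_index convention:

```python
import torch
import torch.nn.functional as F

vocab_size = 30522
logits = torch.randn(2, 8, vocab_size)   # computed for EVERY position
labels = torch.full((2, 8), -100)        # -100 means "not masked, ignore me"
labels[0, 3] = 2023                      # only masked positions carry real labels
labels[1, 5] = 4937

loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    labels.view(-1),
    ignore_index=-100,   # non-masked positions contribute nothing to the loss
)
```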
My question was, how is the GLUE score calculated? 00:11:45.580 |
Is it some accuracy in language generation afterwards? 00:11:51.580 |
Oh, yeah. So GLUE is a big multitask classification benchmark, 00:11:58.540 |
maybe biased toward natural language inference. 00:12:01.620 |
The reason they're using it in the paper is just that it has been 00:12:06.100 |
adopted as a kind of general-purpose measure of performance, 00:12:11.020 |
and it's driven a lot of reasoning about what's good and what's bad in the field. 00:12:24.460 |
They released a small Electra model, and that was designed to be quickly trained on a single GPU, 00:12:37.420 |
so like putting our text into some kind of representation space, 00:12:44.700 |
Or is it just like that generally, GLUE-wise? 00:12:51.860 |
That could kind of queue up some analysis work that you could do for a final project. 00:12:55.980 |
Because I think the insight behind your question is that a lot of the time, 00:12:59.540 |
we reason about these models just based on their performance on something like GLUE. 00:13:05.520 |
what are their internal representations like, 00:13:07.940 |
and are there places where they're transformatively better or worse? 00:13:11.500 |
That you could tie that back to the fact that the learning objective is different. 00:13:15.700 |
We're doing this discrimination thing as opposed to filling in the blanks in some sense. 00:13:20.140 |
Maybe there are some underlying differences. I love that. 00:13:32.500 |
I'll cover seq2seq models just quickly because we're going to do more work with them later. 00:13:36.780 |
We're going to train some of our own from scratch, 00:13:40.420 |
So I thought it would be good to just get them on the table as well. 00:13:44.060 |
Seq2seq: here are some natural tasks that fall into the seq2seq structure. 00:13:52.620 |
Summarization, big text to hopefully smaller text. 00:13:56.720 |
Freeform question answering, where you go from a question and then maybe you're 00:13:59.980 |
generating as opposed to just extracting an answer. 00:14:05.420 |
Semantic parsing, this is the one we're going to tackle, 00:14:07.980 |
where you go from a sentence to some kind of logical form representing its meaning. 00:14:16.380 |
I think there are lots of problems that are pretty naturally cast as seq2seq problems, 00:14:20.980 |
especially when you've got different stuff on the input and the output side. 00:14:27.100 |
Yeah, and the more general class of things we could be talking about would be encoder-decoder, 00:14:32.380 |
where that's just more general in the sense that at that point, 00:14:35.380 |
the input could be a picture you're encoding and the output a text. 00:14:38.980 |
Picture-to-picture, video-to-picture, in principle, 00:14:42.500 |
anything could be happening on the two sides. 00:14:45.740 |
Seq2seq would just be the special case where we're looking at sequential data, 00:14:48.620 |
typically language data or computer code or something. 00:14:56.980 |
Here's a picture from back in the RNN era; this is just nice if you hearken back to that era, if you lived through it. 00:15:06.140 |
This is how you would do seq2seq the traditional way with a recurrent neural network, an RNN. 00:15:10.900 |
You've got the input tokens coming in, and then it transitions into maybe other parameters, 00:15:14.340 |
and it's trying to produce this new sequence left to right coming out. 00:15:17.820 |
Just to remind you, this is part of the journey the field went on. 00:15:23.540 |
A big part of that was thinking a lot about how you would add attention layers into that RNN. 00:15:31.100 |
This is a schematic diagram hinting at the fact that we were moving into an era, 00:15:39.340 |
where basically that attention layer would do all the work. 00:15:52.320 |
There are a few different ways you could think about them. 00:15:54.800 |
On the left is the one I'm nudging you toward, 00:16:02.100 |
This is the standard encoder-decoder pattern, and what essentially that means is that when we do encoding, 00:16:02.100 |
we can connect everything to everything else. 00:16:07.140 |
You can think of that as a process of simultaneously 00:16:10.220 |
encoding the entire input with all its connections. 00:16:14.060 |
But as we do decoding for many of these problems, 00:16:19.700 |
we have to generate left to right. The result of that is that we can look back at the full encoder, 00:16:25.180 |
but for the decoder, we have to do that masking that I 00:16:27.860 |
described with the autoregressive loss last time, 00:16:37.340 |
so that each decoding step is truly sequential. 00:16:40.740 |
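Concretely, that decoder-side masking usually shows up as a lower-triangular mask on the attention scores; a small sketch:

```python
import torch

def causal_attention_weights(q, k):
    # q, k: (batch, heads, seq_len, head_dim)
    seq_len = q.size(-2)
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    # Each position may attend only to itself and to positions on its left.
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~allowed, float("-inf"))
    return scores.softmax(dim=-1)

weights = causal_attention_weights(torch.randn(1, 4, 6, 16), torch.randn(1, 4, 6, 16))
```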
But that's not the only way to think about these problems. 00:16:43.540 |
In fact, I don't want to presuppose that for a sequence-to-sequence thing 00:16:50.340 |
you need an encoder-decoder. You could just use a language model, and the way you might do 00:16:52.260 |
that is to say I'm just going to encode everything left to right. 00:16:57.620 |
Then a kind of compromise position would be that you would take your language model, 00:17:06.620 |
and then begin your process of decoding without 00:17:09.820 |
explicitly having an encoder part and a decoder part. 00:17:13.980 |
I think all of them are on the table and people are solving 00:17:16.740 |
seq2seq tasks right now using all of these variants. 00:17:25.180 |
There are lots out there but these are very prominent ones that you might download. 00:17:32.420 |
T5 came with a big exploration of which of these architectures are effective, 00:17:38.460 |
and what they did is an impressive amount of multitask training, 00:17:45.700 |
An innovative thing that they did is have these task prefixes, 00:17:50.620 |
like translate English to German and then an English sentence, 00:18:00.020 |
and that's the model's cue to take that input and condition it very informally on 00:18:05.500 |
that task so that the output behavior is the expected behavior. 00:18:10.180 |
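As a rough usage sketch with the Hugging Face transformers library (checkpoint name just as an example of the public T5 releases), the task prefix goes right in the input text:

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# The task prefix is the model's cue for what behavior is expected.
text = "translate English to German: The house is wonderful."
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```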
There are lots of T5 models that you can download. 00:18:13.860 |
This is nice for development because some of them are 00:18:15.860 |
very small and some of them are very, very large. 00:18:18.540 |
More recently, there are these FLAN models which took 00:18:21.460 |
a T5 architecture and did a lot of instruction fine-tuning to 00:18:27.060 |
specialize them to different tasks in interesting ways. 00:18:30.660 |
So that's T5, and then the other one that you often hear about that's very effective is BART. 00:18:41.700 |
and it's really got a BERT style thing on the left, 00:18:50.100 |
and then that autoregressive part if you want to do sequential generation. 00:18:53.940 |
The innovative thing about BART is that the training 00:18:57.500 |
involves a lot of corrupting of that input sequence. 00:19:01.220 |
They tried to like do text infilling of pieces, 00:19:12.520 |
and then the model's objective is to learn how to 00:19:15.100 |
essentially uncorrupt what it got as the input. 00:19:18.100 |
They found that the joint process of this text infilling thing and 00:19:22.260 |
sentence shuffling was the most effective for training BART. 00:19:28.220 |
and then when you fine-tune with BART, 00:19:32.460 |
you just put in two uncorrupted copies of your sentence, 00:19:35.600 |
and then you could fit your task specific labels on like 00:19:39.140 |
the class token or the final token of the GPT output, 00:19:42.420 |
and for seek-to-seek, you just use it as a standard encoder-decoder. 00:19:46.380 |
And again, the intuition is that the pre-training phase which did all this corruption, 00:19:50.580 |
has helped the model learn what sequences are like. 00:19:53.980 |
And that blends together for me a lot of the intuitions we've seen from MLM, 00:19:58.700 |
and from what we just talked about with Electra. Yeah. 00:20:03.460 |
Um, kind of a question that I asked last week. 00:20:05.300 |
Have any of these models worked with like spelling mistakes? 00:20:09.180 |
Yeah. So BART is a really good option if you want to do spelling correction. 00:20:13.660 |
And I actually think that might be because spelling correction, as a task, is kind of 00:20:21.100 |
trying to learn the uncorrupted version of the output. 00:20:23.580 |
So I think if you want to do grammar correction, 00:20:28.900 |
and you might just think of training from scratch on a model that you know is 00:20:32.900 |
going to be aware of characters for these character level things. Yeah. 00:20:39.820 |
That was where they, like, removed parts of the text essentially? 00:20:49.140 |
Yeah, where you just, that's more like the BERT style thing where you hide some. 00:20:59.260 |
I just want you all to know about distillation. 00:21:02.220 |
Again, because a theme of this course could be how can we do more with less? 00:21:05.980 |
And distillation is a vision for how we could do more with less. 00:21:09.340 |
Right. We saw this trend in model sizes here where they're getting bigger and bigger, 00:21:13.460 |
and there is some hope that they might now be getting smaller. 00:21:16.980 |
But we should all be pushing to make them ever smaller. 00:21:20.300 |
And one way to think about doing that is distillation. 00:21:22.700 |
And the metaphor here is that we're going to have two models. 00:21:25.460 |
Maybe a really big teacher model that was trained in 00:21:28.540 |
a very expensive way and might run only on a supercomputer, and a much smaller student model. 00:21:34.140 |
And we're going to train that student to mimic the behavior of the teacher. 00:21:38.220 |
And we could do that by just observing the output behavior of the teacher, 00:21:42.380 |
and then trying to get the student to align at the level of the output. 00:21:46.260 |
And that would basically just be treating this teacher as a kind of input output device. 00:21:50.980 |
We could also though think about aligning the internal representations of these two, 00:21:56.020 |
to get a deeper alignment between teacher and student. 00:22:06.060 |
So we could just use our gold data for the task. 00:22:08.940 |
I put that as step zero because you might want it in the mix here, 00:22:13.940 |
We could also learn just from the teacher's output labels. 00:22:19.180 |
but I think the intuition is that the teacher might be doing 00:22:22.260 |
some very complicated regularization that helps the student learn more efficiently. 00:22:27.420 |
So even if there are mistakes in the teacher's behavior, they can still be useful signal for the student. 00:22:32.740 |
You could also think about going one level deeper and using the teacher's full output scores, 00:22:37.620 |
so not just the discrete outputs but the whole distribution that the model predicts. 00:22:42.140 |
And that's what they did in one of the original distillation papers. 00:22:45.900 |
You could also tie together the final output states. 00:22:49.180 |
If the two models have the same layer-wise dimensionality, 00:22:56.720 |
you can add a cosine similarity loss between the output states of teacher and student. 00:23:00.820 |
And now you need to have access to the model itself. 00:23:03.540 |
And this will be much more expensive because you need to actually run the teacher and get at those states. 00:23:08.700 |
You could also think about doing this with lots of other hidden states. 00:23:14.500 |
And you could even, this is a paper that we did, 00:23:16.860 |
try to mimic them under different counterfactuals where you kind of 00:23:20.260 |
change around the input representations of the teacher, 00:23:23.180 |
observe the output, and then try to get the student to 00:23:25.560 |
mimic that very strange behavior from the teacher. 00:23:30.980 |
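A minimal sketch of what the first few of those objectives might look like in PyTorch, assuming a teacher and student that expose logits and final hidden states of matching dimensionality; the temperature and weights are illustrative, in the style of the original distillation work, not taken from any specific paper:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_out, teacher_out, gold_labels,
                      temperature=2.0, alpha=0.5, beta=0.1):
    s_logits, s_hidden = student_out   # logits: (batch, classes), hidden: (batch, dim)
    t_logits, t_hidden = teacher_out

    # (0) Ordinary supervised loss on the gold data.
    hard_loss = F.cross_entropy(s_logits, gold_labels)

    # (2) Match the teacher's full output distribution (soft targets),
    #     softened by a temperature.
    soft_loss = F.kl_div(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2

    # (3) Align final output states with a cosine objective
    #     (needs access to the teacher's internals, hence more expensive).
    cos_loss = 1 - F.cosine_similarity(s_hidden, t_hidden.detach(), dim=-1).mean()

    return hard_loss + alpha * soft_loss + beta * cos_loss
```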
And then there are a bunch of other things you can do. 00:23:35.940 |
Standard distillation is where you keep your big model frozen and the student is being updated by the process. 00:23:41.300 |
There's also a version where there are lots of big models maybe doing 00:23:43.180 |
multiple tasks and you try to distill them all at once down into a student. 00:23:49.860 |
Co-distillation is where they're trained jointly, 00:23:55.500 |
That's where both the teacher and the student are learning together simultaneously. 00:23:59.300 |
Might be unnerving in the classroom but effective for a model. 00:24:02.580 |
And then self-distillation is actually where you try to get like usually 00:24:06.580 |
lower parts of the model to be like other parts of the model by 00:24:10.300 |
having them mimic themselves as part of the core model training. 00:24:15.560 |
It's a special case, I guess, of co-distillation where there's only one model and 00:24:18.600 |
you're trying to distill parts of it into other parts. 00:24:26.860 |
And the reason I can be encouraging about this is that as we get better and better at 00:24:30.980 |
distillation we're finding that distilled models are as good or better than the teacher models. 00:24:36.580 |
Maybe for a fraction of the cost and this is especially 00:24:40.020 |
relevant if the model is being used in production on a small device or something. 00:24:44.580 |
So here are just some GLUE performance numbers that show across a bunch of 00:24:48.140 |
these different papers that with distillation you can still 00:24:51.300 |
get GLUE performance like the teacher with a tiny model. 00:24:59.100 |
Something really puzzling to me is how can a smaller simple model be able to mimic a teacher 00:25:06.500 |
Couldn't you have just trained the simple model? 00:25:09.100 |
Or is it just that the teacher has access to some point of learning that is easier to 00:25:14.180 |
navigate to for a student, but not for the student to get to on its own? 00:25:19.860 |
I think something like what you just said has to be right. 00:25:24.260 |
I actually don't- so you're asking about the special case where the teacher just does 00:25:28.900 |
its input output thing and produces a dataset that we train the student on, right? 00:25:33.460 |
And you're asking why is that better than just training the student on your original data? 00:25:40.660 |
I- the best metaphor I can give you is that it is a kind of regularizer. 00:25:44.740 |
So the teacher is doing something very complicated 00:25:47.500 |
and even its mistakes are useful for the student. 00:25:50.740 |
I guess this may be a simple way that I'm understanding it. 00:25:54.460 |
It's okay to make certain mistakes, and the teacher has figured out which mistakes you can afford to make. 00:26:00.420 |
I like that. I like- that's a beautiful opening line of a paper. 00:26:03.900 |
We need to make it substantive by actually explaining what that means. 00:26:13.220 |
One more question and then I'll just wrap up. 00:26:15.060 |
Do we have some comparisons where the student is, 00:26:17.780 |
I mean, less general versus the teacher? 00:26:21.420 |
I mean, does it overfit the data in a sense, 00:26:30.260 |
I don't know. I mean, you would guess less if it has a tiny capacity. 00:26:33.820 |
It won't have as much capacity to overfit as the teacher. 00:26:37.780 |
And maybe that's why in some situations the students outperform the teachers. 00:26:49.540 |
Architectures I didn't mention: Transformer-XL, a 00:26:52.340 |
wonderful creative attempt to model long sequences by essentially creating 00:26:56.940 |
a recurrent process across cached versions of earlier parts of the long document you're processing. 00:27:04.140 |
XLNet, this is a beautiful and creative attempt to use masked language modeling, sorry, 00:27:10.700 |
an autoregressive language modeling objective but still have 00:27:14.620 |
bidirectional context and they do this by creating all these permutation orders of 00:27:19.580 |
the original sequence so that you can effectively condition on the left and the right, 00:27:30.980 |
DeBERTa is an attempt to separate out the word and positional encodings for 00:27:35.900 |
these models and kind of make the word embeddings more like first-class citizens. 00:27:41.100 |
And that's very intuitive for me because that's like showing that we want 00:27:44.460 |
the model to learn some semantics for these things that's separate from their position. 00:27:49.340 |
And they did that by reorganizing the attention mechanisms. 00:27:53.300 |
The known limitations, we did a good job on these except for this final one. 00:27:57.940 |
BERT assumes that the predicted tokens are all independent of each other given the unmasked tokens. 00:28:02.820 |
I gave you that example of masking 'new' in 'New York' and it thinking that both of 00:28:07.060 |
those are independent of the other given the surrounding context. 00:28:10.460 |
XLNet again addresses that and that might be something that you want to meditate on. 00:28:16.140 |
Pre-training data, here's a whole mess of resources. 00:28:23.260 |
I'm offering these primarily because I think you might want to audit them as 00:28:26.900 |
you observe strange behavior from your large models. 00:28:29.860 |
The data might be the key to figuring out where that behavior came from. 00:28:36.700 |
Autoregressive architectures seem to have taken over, 00:28:39.340 |
but that could be just because everyone is so focused on generation. 00:28:42.940 |
I have an intuition that models like BERT are still 00:28:46.540 |
better if you just want to represent examples as opposed to doing generation. 00:28:50.900 |
Seq2seq is still a dominant choice for tasks with that structure. 00:28:54.780 |
Although again, point 1 might be pushing everyone to just use GPT-3 or 4, 00:29:04.380 |
And then people are still obsessed with scaling up. 00:29:06.860 |
But we might be seeing a counter movement towards smaller models, 00:29:10.380 |
especially with reinforcement learning with human feedback. 00:29:13.340 |
And that's something that we're going to talk about next week and the week after. 00:29:17.660 |
So I kind of restructured a little bit of my talk. 00:29:22.460 |
So like we're only going to get through like part one today, 00:29:26.700 |
how transformers work, and then we're going to talk about the other stuff. 00:29:37.700 |
I work on robotics, kind of channeling one of the core concepts of the class. 00:29:40.820 |
It's all about doing a whole lot with very, very little. 00:29:44.420 |
Like I'm really just working on how we get robots to follow instructions, 00:29:51.740 |
you know, opening a fridge or pouring coffee, things like that. 00:29:59.060 |
it became really, really clear that we needed 00:30:00.860 |
better raw materials, better starting points. 00:30:03.140 |
So I started working on pre-training, first in language, 00:30:08.980 |
and robotics with language, kind of as the central theme. 00:30:12.380 |
So I want to talk today about fantastic language models and how to build them. 00:30:17.100 |
So Richard Feynman is probably not only one of the greatest physicists of all time, 00:30:21.740 |
but he's one of the greatest educators of all time, 00:30:26.420 |
And he has this quote, "What I cannot create, I do not understand." 00:30:31.620 |
And for him, it was really just kind of about building blocks. 00:30:34.340 |
How do I understand what is going on at the lowest level, 00:30:37.060 |
so I can compose them together and figure out what to do next? 00:30:45.620 |
I actually just want to spend the next 12 or so minutes 00:30:51.140 |
building transformers, and how that all happened. 00:30:53.420 |
So it's a practical take on these large-scale language models. 00:30:57.380 |
We're really not going to get to the large-scale bit. 00:31:01.300 |
Again, we're only going to focus on the model architecture, 00:31:02.860 |
but today, the evolution of the transformer, how we got there. 00:31:06.660 |
Training at scale, we'll probably cover some other time. 00:31:12.540 |
very briefly efficient fine-tuning and inference. 00:31:14.340 |
And we have some other great CAs who might actually be talking about this more in depth. 00:31:24.380 |
I trained my first deep learning MNIST model in 2018. 00:31:31.380 |
And with every new model, with every new GPT, 00:31:34.780 |
one, two, three, four, the five that's training right now, 00:31:41.260 |
there are more and more things that are hidden from plain sight that we never get to see. 00:31:45.380 |
And it's been the job of people like us, students, 00:31:51.580 |
to find the insights, find the intuition behind these ideas. 00:31:54.460 |
And in kind of rediscovering those pipelines, 00:31:56.860 |
it's actually our comparative advantage in figuring out, 00:32:00.780 |
okay, so these are how these pieces came to be. 00:32:03.540 |
What do I do next? And so I don't really care about time. 00:32:15.340 |
Stop me if I say anything you don't understand, that's the contract. 00:32:19.260 |
and we're just going to kind of go step by step. 00:32:25.580 |
How did this become the bedrock of language modeling? 00:32:34.780 |
How did we get here? So what is the recipe for a good language model? 00:32:39.460 |
We've talked a bit about contextual representations. 00:32:41.740 |
Chris was talking through kind of the various, 00:32:43.420 |
you know, different phase changes in language modeling history. 00:32:45.900 |
Lisa was talking about diffusion language models, 00:32:54.260 |
The first ingredient: I need massive amounts of cheap, easy-to-acquire data. 00:32:57.300 |
If we're building a language model and we're building these contextual representations, 00:33:01.500 |
we want to learn truths about the world from data at scale. 00:33:05.540 |
And to do that at scale, we need data at scale. 00:33:10.580 |
And the other is we need a simple and high throughput way to consume it. 00:33:18.060 |
We need to be able to chew through all of this data as fast as we possibly can 00:33:22.660 |
in the least opinionated way to figure out, you know, 00:33:26.540 |
all of the possible patterns, all of the possible things that could be useful 00:33:29.900 |
for people fine-tuning, generating, using these models for arbitrary things downstream. 00:33:35.380 |
This isn't just applicable to language, it's applicable to pretty much everything. 00:33:38.580 |
So vision does this, video does this, video and language, 00:33:41.820 |
vision and language, language and robotics, all of them follow a similar strategy. 00:33:48.100 |
So, right, so simple in that it's natural to scale the approach with data as we, 00:33:53.300 |
you know, go from 300 billion to 600 billion tokens, 00:33:56.340 |
maybe make the model bigger to handle that in a pretty simple way. 00:34:02.460 |
The training, the way that we actually ingest this data should be fast and parallelizable 00:34:07.940 |
and we should be, you know, making the most of our hardware. 00:34:10.060 |
If we're going to run a data center with, I don't know, 512 GPUs 00:34:14.340 |
with each 8 GPU box costing $120,000, I'd better be getting my money's worth at the end of the day. 00:34:20.740 |
And the consumption part, right, like these minimal assumptions on relationships: 00:34:25.940 |
the less opinionated I am about how different parts of my data are connected, 00:34:30.540 |
the more I can learn given the first thing, massive amounts of data at scale. 00:34:36.980 |
So, to kind of figure out how we got to the transformer, I want to kind of wind time back 00:36:41.700 |
to kind of what Chris was alluding to earlier with RNNs, right. 00:34:46.020 |
So, this is an RNN model, kind of complicated, but it's from CS224N. 00:34:56.500 |
And it's this very powerful class of model in theory, right. 00:35:00.820 |
I am ingesting arbitrary length sequences left to right 00:35:04.460 |
and I'm learning arbitrary patterns around them. 00:35:07.820 |
People decided later on to, you know, add these attention mechanisms on top of the sequence 00:35:13.340 |
to sequence RNN models to figure out like how to kind of sharpen their focus 00:35:19.940 |
So, the strengths are like I get to handle arbitrarily long context and like we see kind 00:35:23.700 |
of the first semblance of attention appear kind of very motivated 00:35:27.980 |
by like the way we do language translation, right. 00:35:30.020 |
Like when I'm translating word by word, there are certain words in the input 00:35:32.580 |
that are going to matter, that I'm going to sharpen my focus to. 00:35:37.540 |
They're not the most scalable: producing the next token requires me 00:35:44.260 |
to process every earlier token first, one at a time. And I can't really make them deeper without training stability going to pieces. 00:35:51.140 |
And so, chewing through a large amount of data with an RNN is hard. 00:35:54.940 |
Some people refuse to believe that and they've actually done immense work in trying to scale 00:35:59.740 |
up RNNs, make them more parallelizable using lots and lots of tricks. 00:36:04.940 |
And then separately, kind of from the vision community that kind of bled 00:36:09.820 |
into the language community, we have convolutional neural networks. 00:36:13.100 |
And this is from another course, from Lena Voita, about using CNNs for language modeling. 00:36:21.500 |
And the idea here is we have this ability to do immense, deep, parallelizable training 00:36:29.060 |
by kind of taking these like little windows, right. 00:36:32.180 |
I'm going to look at, you know, each layer is only going to look at, like, 00:36:34.540 |
three-word contexts at a time and give me a representation. 00:36:37.460 |
But if I stack this enough times and I have these residual connections that combine, you know, 00:36:42.220 |
earlier inputs with later inputs, by the time I'm like 10 layers deep, 00:36:47.740 |
I've effectively seen a pretty wide window of context. But I need that depth and that's kind of a drawback. 00:36:52.260 |
But there are like really cool, powerful ideas here. 00:36:54.180 |
And I'd actually say that the transformers have way more to do with CNNs and the way 00:36:57.540 |
that they behave than the way RNNs behave, right. 00:37:00.940 |
So we have this idea of a CNN layer kind of having multiple filters, multiple kernels, 00:37:06.380 |
different ways of looking at and extracting features from an image or features from text. 00:37:12.340 |
You have, you know, this ability to kind of scale depth using these residual connections. 00:37:17.380 |
The deepest networks that we had, you know, from 2012, 2015, were CNNs. 00:37:23.940 |
ResNet-152 isn't called 152 because it's, you know, the 152nd edition of the ResNet; it's called that because it's 152 layers deep. 00:37:39.700 |
Every little window that I see at every layer can be computed completely independently 00:37:44.380 |
of every other window, which is really, really great for modern hardware, modern GPUs. 00:37:51.460 |
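A small illustrative sketch of that idea for token sequences: stacked width-3 convolutions with residual connections, where every position in a layer is computed in parallel and the receptive field grows with depth.

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """One width-3 convolution over the sequence, plus a residual connection."""
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):                        # x: (batch, dim, seq_len)
        return x + torch.relu(self.conv(x))      # residual keeps earlier information around

# Stack 10 blocks: every position now "sees" roughly a 21-token window.
model = nn.Sequential(*[ConvBlock(256) for _ in range(10)])
embeddings = torch.randn(8, 256, 128)            # (batch, embedding dim, sequence length)
out = model(embeddings)
```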
So looking at this, seems like CNNs are cool. 00:37:57.060 |
There's a natural question, which is like how do you do better? 00:38:00.700 |
This is the picture from Chris's slides that Lisa also used. 00:38:07.060 |
Like what does self-attention mean in a transformer? 00:38:12.500 |
So one idea, like one key component, one missing component for how you get from a CNN 00:38:19.180 |
or an RNN to a transformer is the idea 00:38:23.140 |
that each individual token is its own query, key, and value. 00:38:27.220 |
It's its own entity that can be used to shape the representations of all of the other tokens. 00:38:33.500 |
Right. So I'm going to turn this word "the" into its own query key and value. 00:38:36.500 |
I'm going to use the attention from the RNNs, 00:38:39.540 |
and I'm going to use the depth parallelizability scaling from the CNNs. 00:38:44.940 |
And then the multi-headed part of self-attention is exactly like what the, you know, 00:38:49.380 |
different convolutional filters, the different kernels are doing in a CNN. 00:38:53.340 |
It's giving you different perspectives on that same token, different ways to come up with queries, 00:38:57.780 |
different insights into like how we can use, you know, a single token and come 00:39:01.900 |
up with multiple different representations and fuse them together. 00:39:09.580 |
It's a kind of very terse description of like what multi-headed self-attention looks like. 00:39:13.900 |
Key parts here are this, you know, little bit where we kind of project an input sequence into queries, keys, and values. 00:39:21.020 |
And then we're just going to rearrange them kind of in some way. 00:39:24.940 |
And the important part here is that like we are rearranging them in a way that splits them 00:39:29.380 |
into these different views, these different heads, where each head has some fixed dimension, 00:39:34.060 |
which is like the key dimension of the transformer. 00:39:36.620 |
And then we have these queries, keys, and values that we're then going to use 00:39:40.580 |
for this attention operation which comes directly from the RNN literature. 00:39:44.740 |
Right? It is really just this dot product between queries and keys. 00:39:48.820 |
That is a very complicated way of saying that's a matrix multiply. 00:39:53.300 |
And then we're going to just project them and combine them back into our tensor of, you know, 00:39:59.700 |
batch size by sequence length by embedding dimension. 00:40:03.340 |
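Here is a compact sketch of that computation, assuming a standard PyTorch setup; the single fused projection and the reshape into heads mirror the "one big weight matrix" point that comes up next.

```python
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        assert embed_dim % n_heads == 0, "heads must evenly divide the hidden dimension"
        self.n_heads, self.head_dim = n_heads, embed_dim // n_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)   # one fused weight matrix
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):                                # x: (batch, seq, embed_dim)
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)           # split into queries, keys, values
        # Rearrange each into (batch, heads, seq, head_dim): the different "views".
        q, k, v = [t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v)]
        scores = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5   # just a matmul
        attn = scores.softmax(dim=-1)
        ctx = (attn @ v).transpose(1, 2).reshape(B, T, D)           # fuse the heads back
        return self.out(ctx)
```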
>> There's a nice subtlety here I think that caught me off guard at one point. 00:40:08.620 |
When you look at these models, like you download BERT and it says multi-headed attention 00:40:12.860 |
or whatever, there's only one set of weights even though it's multi-headed. 00:40:17.780 |
Do you want to unpack that for us a little bit? 00:40:19.980 |
>> Yeah. So the convolutional kernel has kind of this really nice way of expressing 00:40:26.940 |
like I have multiple resolutions of an image. 00:40:29.300 |
Right? It's like that depth channel of a convolutional filter. 00:40:32.300 |
And if you can unpack like the conv 2D layer in PyTorch, you kind of see that come out. 00:40:38.860 |
You see just like this one big weight matrix that is literally like this, you know, dimensionality, 00:40:43.900 |
you know, embed dimension by three times embed dimension is usually what it is in BERT or GPT. 00:40:49.340 |
That three is the way you split it into queries, keys, and values. 00:40:55.060 |
But we're actually going to kind of like take that vector that is like, you know, embed dim size 00:41:00.300 |
and just chunk it up into each of these different filters. 00:41:03.300 |
Right? So rather than make those filters explicit by providing them as a parameter 00:41:07.620 |
that defines some weight layer, for efficiency purposes, 00:41:10.540 |
we're actually just going to treat it all as one matrix and then just chunk it 00:41:13.580 |
up as we're doing the linear algebra operations. 00:41:18.620 |
>> So you're chunking it twice, the times three is queries, keys, and values. 00:41:22.420 |
And then number of heads is further chunking each one of those. 00:41:25.260 |
>> Yeah. So one of the kind of like the key rules that like no one ever tells you 00:41:28.820 |
about transformers is that your number of heads has 00:41:31.940 |
to evenly divide your transformer hidden dimension. 00:41:35.860 |
That's usually a check that is explicitly done in the code for training like BERT or GPT. 00:41:40.580 |
And usually the code doesn't work if that doesn't happen. 00:41:43.220 |
And that's kind of how you get away with a lot of these division tricks. 00:41:46.780 |
>> So we should do a whole course on broadcasting. 00:41:49.700 |
>> Before you even start this so that you can do this mess of things. 00:41:52.740 |
>> Yeah. And there's a great professor at Cornell Tech, Sasha Rush, who kind of has like a bunch 00:41:57.380 |
of like tutorials on just like basic like broadcasting tensor operations. 00:42:03.020 |
>> Yeah. Can you just clarify what you mean by chunking? 00:42:05.220 |
>> Yeah. So if I have a vector of let's say length 1024, right? 00:42:10.500 |
That is my embedding dimension, the hidden dimension for my transformer. 00:42:13.700 |
And if I have keys of, say, dimension two, make it easy. 00:42:19.980 |
Chunking just means that I'm going to split that vector of 1024 into 512 chunks of size two. 00:42:26.020 |
Right? So I'm literally just going to, like, reshape that vector and chunk it up. 00:42:38.060 |
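For instance, with a hidden size of 1024 and a head (key) dimension of 2, that chunking is literally just a reshape:

```python
import torch

hidden = torch.randn(1024)       # one token's vector: embedding dim 1024
heads = hidden.view(512, 2)      # 512 heads, each with key dimension 2
# Same numbers, no new parameters -- just a different view of the same vector.
```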
Is this alone enough to define a transformer? 00:42:45.940 |
So it's good because like we get all of the parallelization advantages and all of kind 00:42:50.700 |
of the attention advantages that I talked about on the previous slide. 00:42:53.060 |
This is a slide from, like, Justin Johnson and Dante Zhu, who are both ex-Stanford alums now teaching 00:42:58.700 |
courses about transformers and deep learning at various colleges. 00:43:02.820 |
But you're missing kind of like one key component, right? 00:43:08.460 |
So if you just look at this and squint at this for like a little bit of time, 00:43:14.180 |
what you're missing is like, okay, so I am just taking different weighted averages 00:43:18.260 |
of the same underlying values over and over again. 00:43:24.940 |
Relative to the things that are coming out of my transformer, 00:43:27.300 |
there actually is no non-linearity with just self-attention. 00:43:31.500 |
This is basically a glorified linear network. 00:43:36.180 |
So we need some way to fix that because that's where the expressivity, 00:43:40.740 |
the kind of the magic of deep learning happens. 00:43:42.420 |
It's kind of when we stack these non-linearities, go deeper and learn new patterns at scale. 00:43:53.220 |
We add an MLP at the very end of the transformer block and it's very simple, right? 00:43:57.180 |
All it does is kind of take the embedding dimension that comes 00:43:59.820 |
out of the self-attention block that we just defined, project it up to a bigger hidden size, apply a non-linearity, 00:44:08.180 |
and then down-project it back to the embedding dimension. 00:44:11.020 |
Usually what you're going to see is a factor of four. 00:44:20.340 |
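A sketch of that block, with a residual connection around it (other pieces like layer norm are left out for brevity):

```python
import torch.nn as nn

class TransformerMLP(nn.Module):
    """The position-wise feed-forward block: up-project, non-linearity, down-project."""
    def __init__(self, embed_dim, expansion=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, expansion * embed_dim),  # lift into a bigger space
            nn.GELU(),                                    # the non-linearity that was missing
            nn.Linear(expansion * embed_dim, embed_dim),  # project back down
        )

    def forward(self, x):          # x: (batch, seq, embed_dim)
        return x + self.net(x)     # residual: keep what's worth remembering
```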
But here's some like soft intuition for kind of why this might work. 00:44:24.660 |
This is kind of a throwback to, like, CS229, you know, OG machine learning days. 00:44:30.060 |
So you want your network as a whole to be able to kind of 00:44:37.900 |
forget the things that don't matter but also remember the things that are important, right? 00:44:41.500 |
The part doing the remembering is these residual connections, 00:44:43.980 |
like the fact that I'm adding X to some transform of X. 00:44:49.860 |
And the MLP is more like the forgetting part. It's basically saying, like, what stuff can I throw away and should I 00:44:53.140 |
basically forget because it's not really relevant to what I care about, 00:44:56.460 |
at the end of the day, which are good contextual representations. 00:44:59.100 |
And so the role of the MLP is very similar to the role of kind of 00:45:03.620 |
a kernel in, you know, the good old support vector machine literature, right? 00:45:07.900 |
So if I have two classes that are kind of like this in a plane, 00:45:11.740 |
and I want to draw a line that partitions them, how do I do it? 00:45:21.300 |
I can't do that with a straight line. But if I just implicitly lift these things up to, like, 3D, 00:45:25.220 |
I can turn this into a shape in 3D that a flat plane can just cut in half and 00:45:28.900 |
separate my stuff. So projecting up with this MLP is basically 00:45:32.700 |
this way of kind of aligning or crystallizing the structure of our features. 00:45:40.780 |
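A toy version of that analogy (just an illustration, not from the lecture): two classes that no straight line separates in 2D become separable by a flat plane once each point gets one extra lifted feature.

```python
import numpy as np

# Class 0: points inside a circle of radius 1; class 1: points outside it.
rng = np.random.default_rng(0)
pts = rng.uniform(-2, 2, size=(1000, 2))
labels = (pts[:, 0] ** 2 + pts[:, 1] ** 2 > 1).astype(int)

# No straight line in 2D separates them, but lift each point to 3D
# with the extra feature z = x^2 + y^2 ...
lifted = np.column_stack([pts, pts[:, 0] ** 2 + pts[:, 1] ** 2])

# ... and the flat plane z = 1 now separates the classes perfectly.
print(np.mean((lifted[:, 2] > 1).astype(int) == labels))   # 1.0
```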
I think we're out of time, so we're going to go through like 00:45:42.700 |
the rest of the transformer evolution in a bit. 00:45:46.540 |
I have office hours tomorrow and I will be back. Thanks.