Stanford CS25: V4 | Jason Wei & Hyung Won Chung of OpenAI
00:00:07.200 |
So he's an AI researcher based in San Francisco, 00:00:11.560 |
He was previously a research scientist at Google Brain, 00:00:18.960 |
instruction tuning, as well as emergent phenomena. 00:00:23.380 |
and he's been here before to give some talks. 00:00:25.320 |
So we're very happy to have you back, Jason, and take it away. 00:00:44.060 |
and then we'll both take questions at the end. 00:00:49.440 |
So I want to talk about a few very basic things. 00:00:55.300 |
And I think the fundamental question that I hope to get at 00:01:14.780 |
to do that I found to be extremely helpful in trying 00:01:17.020 |
to answer this question is to use a tool, which 00:01:33.900 |
So in 2019, I was trying to build one of the first lung 00:01:39.260 |
So there'd be an image, and you have to say, OK, 00:01:55.300 |
And he said, Jason, you need a medical degree 00:02:05.680 |
So I basically looked at the specific type of lung cancer 00:02:20.140 |
And in the end, I learned how to do this task 00:02:25.340 |
And the result of this was I gained intuitions 00:02:32.620 |
OK, so first, I will do a quick review of language models. 00:03:05.880 |
it outputs a probability for every single word 00:03:10.000 |
So vocabulary would be A, aardvark, drink, study, 00:03:26.680 |
to put a probability over every single word here. 00:03:45.520 |
And then the way that you train the language model is you say, 00:03:59.040 |
I want this number here, 0.6, to be as close as possible to 1. 00:04:11.680 |
And you want this loss to be as low as possible. 00:04:14.040 |
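To make this concrete, here is a minimal sketch (mine, not from the talk) of the next-word prediction objective over a toy vocabulary; apart from the 0.6 mentioned above, the probabilities and setup are made up.

```python
import math

# Toy next-word prediction setup: a tiny vocabulary and the model's predicted
# distribution over the next word for some context sentence (made-up numbers;
# the 0.6 mirrors the example in the talk).
vocab = ["a", "aardvark", "drink", "study", "zucchini"]
predicted_probs = [0.05, 0.01, 0.60, 0.30, 0.04]

# Suppose the actual next word in the training text was "drink".
target_index = vocab.index("drink")

# Cross-entropy loss for this one prediction: -log p(correct next word).
# Training pushes p("drink") toward 1, which pushes this loss toward 0.
loss = -math.log(predicted_probs[target_index])
print(f"p(correct word) = {predicted_probs[target_index]:.2f}, loss = {loss:.3f}")
```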
OK, so the first intuition that I would encourage everyone 00:04:56.160 |
So when you train a language model on a large enough 00:05:10.560 |
you have a lot of sentences that you can learn from. 00:05:12.800 |
So for example, there might be some sentence, 00:05:19.440 |
that code should be higher probability than the word 00:05:27.160 |
So somewhere in your data set, there might be a sentence, 00:05:31.560 |
I went to the store to buy papaya, dragon fruit, 00:05:36.900 |
that the probability of durian should be higher than squirrel. 00:05:41.240 |
The language model will learn world knowledge. 00:05:43.420 |
So there will be some sentence on the internet that says, 00:05:48.160 |
and then the language model should learn that it should be 00:05:53.760 |
You can learn traditional NLP tasks, like sentiment analysis. 00:05:58.440 |
I was engaged on the edge of my seat the whole time. 00:06:05.480 |
The next word should probably be good and not bad. 00:06:10.200 |
And then finally, another example is translation. 00:06:30.360 |
Standing next to Iroh, Zuko pondered his destiny. 00:06:34.880 |
And then kitchen should be higher probability than store. 00:06:41.160 |
So you might have, like, some arithmetic exam answer 00:06:46.480 |
And then the language model looks at this and says, 00:06:48.600 |
OK, the next word should probably be 15 and not 11. 00:06:55.120 |
of tasks like this when you have a huge data set. 00:07:03.840 |
And these are sort of like very clean examples of tasks. 00:07:34.920 |
And then now, pretend you're the language model. 00:07:37.880 |
And you could say, like, OK, what's the next word here? 00:07:46.360 |
And so, like, OK, what's the language model learning 00:08:00.760 |
So here, the model is learning, like, basically 00:08:09.840 |
I think it's kind of hard to know, but the answer is A. 00:08:24.680 |
And this, I would say, I don't know what task this is. 00:08:31.040 |
This is, like, you know, it could have been woman. 00:08:41.960 |
is that the next word prediction task is really challenging. 00:08:45.520 |
So, like, if you do this over the entire database, 00:09:01.520 |
is scaling, which is, by the way, let's say scaling compute. 00:09:15.960 |
And by the way, compute is equal to how much data you have 00:09:47.200 |
I would encourage you guys to read the paper. 00:09:49.600 |
And what this basically says is you can have a plot here, 00:09:56.320 |
where the x-axis is compute and the y-axis is loss. 00:10:05.920 |
And what this intuition says is you can train one language 00:10:13.960 |
So you can train the next one, you'll have that loss. 00:10:16.480 |
If you train the one after that, you'll have that loss. 00:10:19.120 |
Then if you train the one after that, you'll have that loss. 00:10:22.680 |
And you can basically predict the loss of a language model 00:10:26.300 |
based on how much compute you're going to use to train it. 00:10:32.560 |
is that in this paper, they showed that the x-axis here 00:10:39.320 |
So basically, it would be surprising if the trend broke 00:10:50.200 |
Because if it went like that, then it would saturate. 00:10:55.120 |
And then putting more compute or training a larger language model wouldn't help. 00:11:00.440 |
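As a hedged sketch of what "predicting the loss from compute" can look like in practice: the functional form below is the usual power law from scaling-law papers, and every number is invented for illustration.

```python
import numpy as np

# Hypothetical use of a scaling law: fit a power law loss ≈ a * C^slope to a
# few small runs, then extrapolate to a much bigger run. All numbers made up.
compute = np.array([1e17, 1e18, 1e19, 1e20])   # training FLOPs of small runs
loss    = np.array([3.8, 3.1, 2.55, 2.1])      # their final losses

# A power law is a straight line in log-log space.
slope, log_a = np.polyfit(np.log(compute), np.log(loss), 1)

# Predict the loss of a much larger run before training it.
big_run = 1e22
predicted = np.exp(log_a) * big_run ** slope
print(f"slope = {slope:.3f}, predicted loss at 1e22 FLOPs ≈ {predicted:.2f}")
```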
So I think a question that we don't have a good answer 00:11:18.660 |
to as a field, but I'll give you a hand-wavy answer, 00:11:22.980 |
is why does scaling up the size of your language model 00:11:28.980 |
And I'll give two basically hand-wavy answers. 00:11:39.340 |
So one thing that's important is how good is your language model 00:11:50.780 |
and you see a bunch of facts on the internet. 00:11:53.660 |
You have to be pretty choosy in which facts you memorize. 00:11:57.220 |
Because if you don't have that many parameters, 00:11:59.140 |
you're like, oh, I can only memorize a million facts. 00:12:22.020 |
And then the other hand-wavy answer I'll give 00:12:25.460 |
is small language models tend to learn first-order heuristics. 00:12:35.260 |
you're already struggling to get the grammar correct. 00:12:37.980 |
You're not going to do your best to try to get the math 00:12:46.860 |
you have a lot of parameters in your forward pass. 00:12:50.100 |
And you can try to do really complicated things 00:13:47.540 |
By this, I mean if you take some corpus of data 00:13:53.580 |
and you compute the overall loss on every word in that data set. 00:13:58.140 |
The overall loss, because next word prediction is 00:14:04.060 |
massively multi-task, can be decomposed into the loss of individual tasks. 00:14:08.700 |
So you have, I don't know, some small number times the loss 00:14:14.660 |
of, say, grammar plus some small number times 00:14:21.300 |
the loss of sentiment analysis plus some small number 00:14:47.500 |
And so you can basically write your overall loss 00:16:05.940 |
some part of that overall loss that scales like this. 00:16:13.220 |
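One way to write down the decomposition he is describing; the task names and weights are illustrative, not taken from any specific paper.

```latex
L_{\text{overall}}
  = \sum_i w_i \, L_{\text{task}_i}
  \approx w_1 L_{\text{grammar}}
        + w_2 L_{\text{sentiment}}
        + w_3 L_{\text{world knowledge}}
        + w_4 L_{\text{math}}
        + \cdots,
\qquad \text{each } w_i \text{ small}.
```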
So for example, you could say, if GPT 3.5 is there 00:16:25.780 |
like this for, say, doing math or harder tasks, 00:16:30.380 |
where the difference between GPT 3.5 and GPT 4 00:16:37.420 |
And it turns out you can look at a big set of tasks, which 00:16:58.900 |
There's this corpus called BIG-Bench, which has 200 tasks. 00:17:06.300 |
So you have-- this was 29% of tasks that were smooth. 00:17:17.380 |
So if I draw the scaling plot, compute is on the x-axis. 00:17:23.380 |
And then here we have accuracy instead of loss. 00:17:36.860 |
So if you have your scaling curve, it'll just all be 0. 00:17:52.580 |
gets worse as you increase the size of the language model. 00:17:57.180 |
And then I think this was 13% that were not correlated. 00:18:21.380 |
And what I mean by that is if you plot your compute 00:18:28.140 |
and accuracy, for a certain point, up to a certain point, 00:18:37.460 |
And then the accuracy suddenly starts to improve. 00:18:42.020 |
And so you can define an emergent ability basically 00:18:54.700 |
And then for large models, you have much better 00:19:06.660 |
is, let's say you had only trained the small language 00:19:10.880 |
You would have predicted that it would have been impossible 00:19:13.460 |
for the language model to ever perform the task. 00:19:16.580 |
But actually, when you train the larger model, 00:19:18.540 |
the language model does learn to perform the task. 00:19:39.140 |
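A toy construction of mine (not from the talk) showing one hand-wavy way an ability can look emergent on an accuracy-style metric even while the model improves smoothly; it also connects to the metric question raised in the Q&A later. All numbers are invented.

```python
import numpy as np

# Assume per-token accuracy rises smoothly with log-compute, but the task is
# only scored correct if all k tokens of a k-token answer are right.
log_compute = np.arange(18, 25)                         # hypothetical log10(FLOPs)
p_token = np.linspace(0.50, 0.99, len(log_compute))     # smooth, gradual gain
k = 20                                                  # answer length in tokens

exact_match = p_token ** k                              # near zero, then a sharp rise

for lc, p, em in zip(log_compute, p_token, exact_match):
    print(f"10^{lc} FLOPs: per-token acc {p:.2f}, exact match {em:.4f}")
```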
is something called inverse scaling slash u-shaped scaling. 00:19:45.660 |
OK, so I'll give a tricky prompt to illustrate this. 00:19:55.620 |
Repeat after me, all that glisters is not glib. 00:20:23.020 |
And this is the prompt I give to the language model. 00:20:31.580 |
is glib, because you asked to repeat after me. 00:20:40.460 |
have an extra small language model, a small language model, 00:20:40.460 |
The performance for the extra small language model 00:20:55.580 |
The small language model is actually worse at this task. 00:21:14.380 |
And the answer is, you can decompose this prompt 00:21:26.220 |
into three subtasks that are basically being done. 00:21:30.620 |
So the first subtask is, can you repeat some text? 00:21:39.340 |
extra small, small, large, and then here is 100. 00:22:02.020 |
So the quote is supposed to be, all that glisters is not gold. 00:22:05.860 |
And so you can then plot, again, extra small, small, large. 00:22:14.740 |
Well, the small language model doesn't know the quote. 00:22:19.860 |
And then the extra small doesn't know the quote, 00:22:47.060 |
And you could say, OK, what's the performance 00:22:52.220 |
And the small model can't do it, or the extra small model 00:23:07.540 |
And then why does this explain this behavior here? 00:23:35.060 |
And then the large model, it can do all three. 00:23:44.740 |
And so that's how, if you look at the individual subtasks, 00:23:52.700 |
you can explain the behavior of some of these weird scaling curves. 00:23:56.580 |
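To make the subtask argument concrete, here is a toy sketch of mine; the per-subtask success rates are invented, chosen only so that composing them reproduces the U-shape he describes.

```python
# Model the "repeat after me: all that glisters is not glib" prompt as three
# subtasks and see how their composition gives U-shaped scaling.
models = ["xs", "small", "large"]

# Hypothetical per-subtask success rates for each model size.
can_repeat_text  = {"xs": 0.9, "small": 0.95, "large": 1.0}   # subtask 1: repeat text
knows_real_quote = {"xs": 0.1, "small": 0.9,  "large": 1.0}   # subtask 2: "... is not gold"
follows_override = {"xs": 0.0, "small": 0.1,  "large": 0.95}  # subtask 3: obey "repeat after me"

for m in models:
    # The model answers "glib" either because it blindly repeats without
    # knowing the quote, or because it knows the quote but still follows
    # the explicit instruction to repeat.
    p_correct = can_repeat_text[m] * (
        (1 - knows_real_quote[m]) + knows_real_quote[m] * follows_override[m]
    )
    print(f"{m:>5}: P(answers 'glib') = {p_correct:.2f}")
```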
So I will conclude with one general takeaway, which 00:24:15.460 |
And the takeaway is to just plot scaling curves. 00:24:28.380 |
So let's say I do something for my research project. 00:24:32.340 |
I fine-tune a model on some number of examples. 00:24:48.460 |
Of not doing whatever my research project is. 00:24:56.660 |
The reason you want to plot a scaling curve for this 00:25:02.740 |
And you find out that the performance is actually here. 00:25:11.940 |
have to collect all the data to do your thing. 00:25:19.900 |
Another scenario would be if you plotted that point 00:25:23.860 |
and it was there, then your curve will look like this. 00:25:54.660 |
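As a sketch of this advice (all numbers invented), you might fine-tune on a few subset sizes and plot performance against data on a log axis before deciding whether to collect more.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical results from fine-tuning on subsets of your data
# (1k, 2k, 4k, 8k examples) instead of only the full set.
n_examples = np.array([1_000, 2_000, 4_000, 8_000])
accuracy   = np.array([0.52, 0.58, 0.63, 0.67])   # made-up numbers

plt.semilogx(n_examples, accuracy, "o-")
plt.xlabel("number of fine-tuning examples")
plt.ylabel("accuracy")
plt.title("Scaling curve for a research project (illustrative)")
plt.savefig("scaling_curve.png")
# If the points keep climbing roughly linearly in log(n), collecting more data
# is likely to help; if they have flattened, more data probably will not.
```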
And I'm happy to take a few questions beforehand 00:26:23.900 |
how do you differentiate between good data and bad data? 00:26:27.820 |
The question is-- or the answer is, you don't really. 00:26:31.380 |
But you can try to only train on good data, 00:26:37.620 |
and filter out some data if it's not from a reliable data source. 00:26:46.180 |
behind the intuition for one or two of the examples, 00:26:48.780 |
like emergent, or why tail knowledge starts to develop? 00:26:55.980 |
What do you mean, intuition behind intuition? 00:26:58.520 |
Intuitively, these concepts that you're seeing in the graphs. 00:27:19.780 |
--attention, in an intuitive sense, not in a-- 00:27:26.620 |
makes the language model better at memorizing tail knowledge 00:27:35.860 |
Yeah, I think it's definitely related to the size 00:27:41.620 |
you could encode probably a more complex function within that. 00:27:49.780 |
you could probably encode more facts about the world. 00:27:54.140 |
And then if you want to repeat a fact or retrieve something, 00:28:05.100 |
So when you were studying the 200-ish problems 00:28:07.980 |
in BIG-Bench, you noticed that 22% were flat. 00:28:11.580 |
But there's a possibility that if you were to increase 00:28:18.300 |
you were looking at the 33% that turned out to be emergent, 00:28:21.300 |
did you notice anything about the loss in the flat portion 00:28:24.060 |
that suggested that it would eventually become emergent? 00:28:32.980 |
The question is, when I looked at all the emergent tasks, 00:28:36.980 |
was there anything that I noticed before the emergence 00:28:40.140 |
point in the loss that would have hinted that it 00:28:48.380 |
You can look at the loss, and it kind of gets better. 00:28:56.340 |
because you might not have all the intermediate points 00:29:16.580 |
are the biggest bottlenecks for current large language models? 00:29:19.380 |
Is it the quality of data, the amount of compute, 00:29:25.900 |
I guess if you go back to the scaling laws paradigm, what 00:29:32.860 |
it says is that if you increase the size of the data 00:29:40.540 |
And I think we'll probably try to keep increasing those things. 00:29:44.380 |
And then the last one, what are your thoughts on the paper, 00:29:49.220 |
Are emergent abilities of large language models a mirage? 00:29:55.640 |
I guess I would encourage you to read the paper 00:29:59.860 |
But I guess what the paper says is if you change the metric a 00:30:10.700 |
I think the language model abilities are real. 00:30:25.340 |
All right, so thanks, Jason, for the very insightful talk. 00:30:36.700 |
He has worked on various aspects of large language models, 00:30:40.780 |
things like pre-training, instruction fine-tuning, 00:30:43.380 |
reinforcement learning with human feedback, reasoning, 00:30:48.060 |
And some of his notable works include the scaling Flan 00:30:50.820 |
papers, such as Flan-T5, as well as Flan-PaLM, and T5X, 00:30:55.220 |
the training framework used to train the PaLM language model. 00:31:06.220 |
All right, my name is Hyung Won, and really happy 00:31:30.500 |
I'm giving a lecture on transformers at Stanford. 00:31:35.380 |
And I thought, OK, some of you in this room and in Zoom 00:31:47.220 |
So that could be a good topic to think about. 00:31:50.140 |
And when we talk about something into the future, 00:31:53.620 |
the best place to get an advice is to look into the history. 00:31:57.860 |
And in particular, look at the early history of transformer 00:32:04.660 |
And the goal will be to develop a unified perspective in which 00:32:10.380 |
we can look into many seemingly disjoint events. 00:32:17.340 |
to project into the future what might be coming. 00:32:20.260 |
And so that will be the goal of this lecture. 00:32:30.020 |
Everyone I see is saying AI is advancing so fast that 00:32:36.420 |
And it doesn't matter if you have years of experience. 00:32:39.340 |
There's so many things that are coming out every week 00:32:43.420 |
And I do see many people spend a lot of time and energy 00:32:46.940 |
catching up with the latest developments, the cutting 00:32:52.580 |
And then not enough attention goes into the older things 00:32:54.980 |
because they become deprecated and no longer relevant. 00:33:00.740 |
But I think it's important, actually, to look into that. 00:33:05.460 |
when things are moving so fast beyond our ability 00:33:08.060 |
to catch up, what we need to do is study the change itself. 00:33:11.460 |
And that means we can look back at the previous things 00:33:16.660 |
and try to map how we got here and from which we can look 00:33:23.060 |
So what does it mean to study the change itself? 00:33:28.140 |
First, we need to identify the dominant driving force, 00:33:35.900 |
because typically, a change has many, many driving forces. 00:33:41.420 |
because we're not trying to get really accurate. 00:33:43.460 |
We just want to have the sense of directionality. 00:33:46.220 |
Second, we need to understand the driving force really well. 00:33:49.220 |
And then after that, we can predict the future trajectory 00:34:02.820 |
I think it's actually not that impossible to predict 00:34:06.660 |
some future trajectory of a very narrow scientific domain. 00:34:17.540 |
and then make your prediction accuracy from 1% to 10%. 00:34:25.500 |
Say, one of them will be really, really correct, 00:34:37.620 |
that you really have to be right a few times. 00:34:40.340 |
So if we think about why predicting the future 00:34:47.260 |
is difficult, or maybe even think about the extreme case 00:34:53.100 |
with perfect accuracy, almost perfect accuracy. 00:34:55.460 |
So here, I'm going to do a very simple experiment 00:34:58.500 |
of dropping this pen and follow this same three-step process. 00:35:04.100 |
So we're going to identify the dominant driving force. 00:35:07.300 |
First of all, what are the driving forces acting 00:35:12.300 |
We also have, say, air friction if I drop it. 00:35:17.380 |
And that will cause what's called a drag force acting 00:35:21.380 |
And actually, depending on how I drop this, the orientation, 00:35:25.620 |
the aerodynamic interaction will be so complicated 00:35:28.860 |
that we don't currently have any analytical way of modeling 00:35:32.700 |
We can do it with the CFD, the computational fluid dynamics, 00:35:38.780 |
This is heavy enough that gravity is probably the dominant one. 00:35:44.100 |
Second, do we understand this dominant driving force, which 00:35:47.780 |
And we do because we have this Newtonian mechanics, which 00:35:52.500 |
And then with that, we can predict the future trajectory 00:35:56.420 |
And if you remember from this dynamics class, 00:36:06.180 |
And then 1/2 g t^2 will give a precise trajectory of the pen. 00:36:13.500 |
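For reference, a minimal statement of the constant-gravity prediction he is invoking (standard kinematics, ignoring drag):

```latex
y(t) = y_0 - \tfrac{1}{2} g t^{2},
\qquad t_{\text{fall}} = \sqrt{\tfrac{2 y_0}{g}},
\qquad g \approx 9.8\ \mathrm{m/s^2}.
```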
So if there is a single driving force that we really 00:36:17.140 |
understand, it's actually possible to predict 00:36:21.500 |
So then why do we feel that predicting the future 00:36:30.660 |
in general is so hard? It's because the sheer number 00:36:33.140 |
of dominant driving forces acting on the general prediction 00:36:41.140 |
is so large that we cannot predict in the most general sense. 00:36:47.420 |
X-axis, we have a number of dominant driving forces. 00:36:52.180 |
So on the left-hand side, we have dropping a pen. 00:37:00.540 |
And then as you add more stuff, it just becomes impossible. 00:37:08.180 |
And you might think, OK, I see all the time things 00:37:15.220 |
And some people will come up with a new agent, new modality, 00:37:22.900 |
I'm not even able to catch up with the latest thing. 00:37:25.700 |
How can I even hope to predict the future of the AI research? 00:37:31.780 |
because there is a dominant driving force that 00:37:35.180 |
is governing a lot, if not all, of the AI research. 00:37:41.460 |
to point out that it's actually closer to the left 00:37:45.180 |
than to the right than we actually may perceive. 00:38:00.540 |
on the technical stuff, which you can probably 00:38:07.780 |
And for that, I want to share what my opinion is. 00:38:14.580 |
And by no means am I saying this is correct. 00:38:27.260 |
And on the y-axis, we have the calculations, in FLOPs: 00:38:31.620 |
if you pay $100, how much computing power do you get? 00:38:37.500 |
And then x-axis, we have a time of more than 100 years. 00:38:45.100 |
And I don't know any trend that is as strong and as 00:38:51.300 |
So whenever I see this kind of thing, I should say, OK, 00:38:57.540 |
And better, I should try to leverage as much as possible. 00:39:01.700 |
And so what this means is you get 10x more compute 00:39:07.180 |
every five years if you spend the same amount of dollar. 00:39:10.460 |
And so in other words, you get the cost of compute 00:39:24.180 |
But that is, I think, really important to think about. 00:39:35.180 |
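A quick back-of-envelope restatement of that trend (my own arithmetic, not a quote from the slide):

```python
import math

# 10x more compute per dollar every 5 years, as described above.
annual_factor = 10 ** (1 / 5)                   # ≈ 1.58x per year
doubling_time = 5 * math.log(2) / math.log(10)  # years for compute-per-dollar to double
print(f"{annual_factor:.2f}x per year, doubling roughly every {doubling_time:.1f} years")
```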
Let's think about the job of the AI researchers. 00:39:37.580 |
It is to teach machines how to think in a very general sense. 00:39:41.260 |
And one somewhat unfortunately common approach 00:39:45.220 |
is we think about how we teach machines how we think we think. 00:39:55.420 |
try to incorporate that into some kind of mathematical model 00:40:08.860 |
that we try to model something that we have no idea about. 00:40:11.860 |
And what happens if we go with this kind of approach 00:40:19.020 |
And so you can maybe get a paper or something. 00:40:23.780 |
because we don't know how this will limit further scaling up. 00:40:42.180 |
And bitter lesson is, I think, the single most important piece 00:40:48.660 |
And it says-- this is my wording, by the way-- 00:40:54.300 |
can be summarized as developing progressively more 00:40:57.980 |
general methods with weaker modeling assumptions 00:41:01.020 |
or inductive biases, and adding more data and compute-- 00:41:04.860 |
And that has been the recipe of entire AI research, 00:41:10.380 |
And if you think about this, the models of 2000 00:41:14.460 |
were a lot more difficult than what we use now. 00:41:17.900 |
And so it's much easier to get into AI nowadays 00:41:22.900 |
So this is, I think, really the key information. 00:41:27.820 |
We have this compute cost that's going down exponentially. 00:41:36.260 |
And just try to leverage that as much as possible. 00:41:38.740 |
And that is the driving force that I wanted to identify. 00:41:43.500 |
And I'm not saying this is the only driving force. 00:42:00.800 |
with more structure, more modeling assumptions, fancier 00:42:06.420 |
What you see is typically you start with a better performance 00:42:13.140 |
But it plateaus because of some kind of structure backfiring. 00:42:17.460 |
because we give a lot more freedom to the model, 00:42:21.380 |
But then as we add more compute, it starts working. 00:42:40.220 |
This red one here, it will pick up a lot later 00:42:49.500 |
We cannot indefinitely wait for the most general case. 00:42:54.940 |
where our compute situation is at this dotted line. 00:42:57.940 |
If we're here, we should choose this less structure 00:43:01.060 |
one as opposed to this even less structure one, 00:43:14.260 |
And so the difference between these two methods 00:43:30.340 |
algorithmic development, and architecture that we have, 00:43:39.580 |
And that has been really how we have made so much progress. 00:43:48.020 |
when we have more compute, better algorithm, or whatever. 00:43:51.780 |
And as a community, we do adding structure very well, 00:43:58.140 |
because with, like, papers, you add a nice structure, then you get a paper. 00:44:01.620 |
But removing that doesn't really get you much. 00:44:06.260 |
And I think we should do a lot more of those. 00:44:08.700 |
So maybe another implication of this bitter lesson 00:44:11.820 |
is that because of this, what is better in the long term 00:44:35.220 |
it's more chaotic at the beginning, so it doesn't work. 00:44:41.980 |
we can put in more compute and then it can be better. 00:44:44.940 |
So it's really important to have this in mind. 00:44:51.060 |
this dominant driving force behind the AI research. 00:45:04.420 |
the next step is to understand this driving force better. 00:45:11.980 |
And for that, we need to go back to some history 00:45:15.540 |
of transformer, 'cause this is a transformers class, 00:45:21.180 |
that were made by the researchers at the time 00:45:33.580 |
And we'll go through some of the practice of this. 00:45:42.020 |
So now we'll go into a little bit of the technical stuff. 00:45:45.700 |
Transformer architecture, there are some variants. 00:46:02.540 |
which you can think of as a current like GPT-3 00:46:07.060 |
This has a lot less structure than the encoder decoder. 00:46:09.940 |
So these are the three types we'll go into detail. 00:46:12.820 |
Second, the encoder only is actually not that useful 00:46:21.820 |
and then spend most of the time comparing one and three. 00:46:29.700 |
So first of all, let's think about what a transformer is. 00:46:32.620 |
Just at a very high level or first principles, 00:46:38.820 |
And sequence model has an input of a sequence. 00:46:42.380 |
So sequence of elements can be words or images or whatever. 00:46:49.180 |
In this particular example, I'll show you with the words. 00:46:55.980 |
'cause we have to represent this word in computers, 00:47:00.540 |
which requires just some kind of an encoding scheme. 00:47:00.540 |
So we just do it with a fixed number of integers 00:47:12.820 |
is to represent each sequence element as a vector, 00:47:15.780 |
dense vector, because we know how to multiply them well. 00:47:21.420 |
And finally, this sequence model will do the following. 00:47:30.060 |
And we do that by letting them take the dot product; 00:47:30.060 |
if the dot product is high, we can say semantically they are more related 00:47:35.660 |
And the transformer is a particular type of sequence model that uses attention. 00:47:50.620 |
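A minimal sketch of mine of the setup described so far: integer ids, dense vectors, and dot products as the measure of how related two sequence elements are.

```python
import numpy as np

# Words -> integer ids -> dense vectors, with dot products measuring relatedness.
vocab = {"that": 0, "is": 1, "good": 2}
rng = np.random.default_rng(0)
embedding = rng.normal(size=(len(vocab), 8))     # one 8-dim vector per word

ids = [vocab[w] for w in ["that", "is", "good"]]
vectors = embedding[ids]                         # the encoded sequence

# Pairwise dot products: the raw ingredient attention uses to decide
# which elements should "talk to" each other.
similarity = vectors @ vectors.T
print(similarity.shape)   # (3, 3)
```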
So let's get into the details of this encoder decoder, 00:47:57.180 |
So let's go into a little bit, a piece at a time. 00:48:03.860 |
of machine translation, which used to be a very cool thing. 00:48:03.860 |
And so you have an English sentence, "that is good", 00:48:08.100 |
So first thing is to encode this into a dense vector. 00:48:22.860 |
And then we have to let them take the dot product. 00:48:25.180 |
So these lines represent which elements can talk 00:48:33.940 |
we take what is called the bidirectional attention. 00:48:38.460 |
And then we have this MLP or feed forward layer, 00:48:44.220 |
You just do some multiplication just because we can do it. 00:48:49.380 |
And then that's one layer, and we repeat that n times. 00:48:55.540 |
And at the end, what you get is the sequence of vectors, 00:49:11.260 |
So here we put in as an input what the answer should be. 00:49:19.500 |
and then das ist gut, I don't know how to pronounce it, 00:49:21.420 |
but that's the German translation of that is good. 00:49:23.860 |
And so we kind of go through the similar process. 00:49:37.980 |
So we cannot, when we train it, we should limit that. 00:49:47.420 |
So after this, you can get after, again, N layers, 00:49:58.880 |
this is a general encoder-decoder architecture. 00:50:07.020 |
Now I'll point out some important attention patterns. 00:50:19.380 |
That is done by this cross-attention mechanism 00:50:22.660 |
which is just that each vector's representation 00:50:28.220 |
should attend to some of them in the encoder. 00:50:33.700 |
which is interesting is that all the layers in the decoder 00:50:37.520 |
attend to the final layer output of the encoder. 00:50:40.600 |
I will come back to the implication of this design. 00:50:46.580 |
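A small sketch of mine of the three attention patterns just described, written as boolean masks where True means a query position may attend to a key position.

```python
import numpy as np

n_enc, n_dec = 4, 3   # e.g. "that is good </s>" -> "das ist gut"

# Encoder self-attention: bidirectional, every input token sees every input token.
encoder_self = np.ones((n_enc, n_enc), dtype=bool)

# Decoder self-attention: causal, target position i sees only positions <= i.
decoder_self = np.tril(np.ones((n_dec, n_dec), dtype=bool))

# Cross-attention: every decoder position sees every (final-layer) encoder output.
cross = np.ones((n_dec, n_enc), dtype=bool)

print(encoder_self, decoder_self, cross, sep="\n\n")
```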
And now, move on to the second type of architecture, 00:51:06.740 |
And that is a sequence of vectors that represents the input sequence. 00:51:13.100 |
And then, let's say we do some kind of a sentiment analysis. 00:51:23.140 |
And that's required for all these task-specific cases. 00:51:43.560 |
This was how the field really advanced at the time. 00:51:56.540 |
that was put into this particular architecture 00:51:58.820 |
is that we're gonna give up on the generation. 00:52:02.540 |
If we do that, it becomes a lot simpler problem. 00:52:07.140 |
we're talking about sequence to classification labels, 00:52:17.540 |
was like, we sometimes call it BERT engineers. 00:52:27.660 |
And, but if we look at from this perspective, 00:52:38.780 |
but in the long term, it's not really useful. 00:53:01.620 |
And so there's a misconception that some people think 00:53:09.000 |
so it cannot be used for supervised learning. 00:53:12.220 |
The trick is to have this input, "that is good", 00:53:17.020 |
And if you do that, then it just becomes simple, 00:53:21.620 |
So what we do is the self-attention mechanism here 00:53:24.840 |
is actually handling both the cross-attention between input and target 00:53:29.380 |
and the sequence learning within each. 00:53:34.660 |
And then, as I mentioned, the output is a sequence. 00:53:38.780 |
And then the key design features are self-attention 00:53:41.500 |
is serving both roles, and we are, in some sense, 00:53:45.380 |
sharing the parameters between input and target. 00:53:56.140 |
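And a matching sketch (again mine) of the decoder-only framing: concatenate input and target into one sequence, and let a single causal self-attention play both roles.

```python
import numpy as np

# Input and target concatenated into one sequence; one set of parameters
# processes both, and one causal self-attention serves as "cross-attention"
# (target attending to input) and within-sequence attention at the same time.
tokens = ["that", "is", "good", "<sep>", "das", "ist", "gut"]
n = len(tokens)

causal_mask = np.tril(np.ones((n, n), dtype=bool))

# e.g. the query "das" (index 4) can attend to the whole English input
# plus the separator, but not to the future German tokens.
print([tokens[j] for j in range(n) if causal_mask[4, j]])
```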
So I think there are many, they look very different, 00:54:03.440 |
And I argue that they're actually quite similar. 00:54:07.820 |
And so to illustrate that, we're gonna transform, 00:54:21.640 |
those additional structures, are they relevant nowadays? 00:54:24.340 |
Now that we have more compute, better algorithm, and so on. 00:54:32.500 |
And then, as we go through, we'll populate this table. 00:54:36.300 |
So let's first look at this additional cross-attention. 00:54:42.060 |
is an encoder-decoder, which has this additional red block, 00:54:44.760 |
the cross-attention, compared to the simpler one 00:54:47.900 |
So we wanna make the left closer to the right. 00:54:51.700 |
So that means we need to either get rid of it, or something. 00:55:01.820 |
actually have the same number of parameters, same shape. 00:55:05.100 |
So that's the first step, share both of these. 00:55:07.260 |
And then it becomes mostly the same mechanism. 00:55:22.300 |
encoder-decoder architecture uses the separate parameters. 00:55:38.400 |
Third difference is the target-to-input attention pattern. 00:55:41.500 |
So we need to connect the target to the input, 00:55:46.140 |
In the encoder-decoder case, we had this cross-attention, 00:56:00.580 |
attending to the final layer output of the encoder. 00:56:11.420 |
we are looking at the same layer representation 00:56:20.940 |
we have to bring back this attention to each layer. 00:56:24.420 |
So now layer one will be attending to layer one of this. 00:56:28.460 |
And finally, the last difference is the input attention. 00:56:33.380 |
I mentioned about this bidirectional attention, 00:56:50.260 |
these two architectures are almost identical. 00:56:53.440 |
There's a little bit of difference in the cross attention, 00:57:00.180 |
these two architectures on the same task, same data, 00:57:02.500 |
I think you will get pretty much within the noise, 00:57:04.300 |
probably closer than if you train the same thing twice. 00:57:11.940 |
Now we'll look at what are the additional structures, 00:57:21.140 |
And then, so we can say that encoder-decoder, 00:57:26.060 |
has these additional structures and inductive biases built in. 00:57:34.820 |
What the encoder-decoder assumes as a structure is that 00:57:42.260 |
it'll be useful to use separate parameters. 00:57:54.100 |
Back when the transformer was introduced in 2017, 00:58:10.380 |
So in that task, we have this input and target 00:58:30.540 |
Modern language modeling is about learning knowledge. 00:58:42.860 |
So does it make sense to have a separate parameter 00:58:55.440 |
And if we represent them in separate parameters, 00:59:14.040 |
and with Jason, we did this instruction fine-tuning work. 00:59:17.720 |
And what this is, is you take the pre-trained model, 00:59:21.080 |
and then just fine-tune on academic data set, 00:59:28.960 |
but here, let's think about the performance gain 00:59:38.800 |
which is T5-based, which is encoder-decoder architecture. 00:59:52.440 |
And then at the end, we just spent three days on T5. 00:59:55.180 |
But the performance gain was a lot higher on this. 01:00:07.000 |
So my hypothesis is that it's about the length. 01:00:12.000 |
So academic data sets we use, we use like 1,832 tasks, 01:00:16.360 |
and here, they have this very distinctive characteristic: 01:00:22.280 |
the input is long in order to make the task more difficult, 01:00:26.400 |
but the target has to be short, because otherwise there's no way to grade it. 01:00:31.560 |
So what happens is you have a long text of input, 01:00:36.060 |
And so this is kind of the length distribution 01:00:49.280 |
and a very different type of sequence going into the target. 01:00:54.720 |
has an assumption that they will be very different. 01:00:57.000 |
That structure really shines because of this. 01:01:03.680 |
why this architecture really was just suitable 01:01:26.960 |
doesn't mean that we are not interested in them. 01:01:29.160 |
Actually, if anything, we are more interested in that. 01:01:31.180 |
So now, we have this longer target situation. 01:01:39.040 |
And moreover, we think about this chat application, 01:01:59.980 |
So that was the first inductive bias we just mentioned. 01:02:05.740 |
target element can only attend to the fully encoded ones, 01:02:11.620 |
Let's look at this additional structure, what that means. 01:02:28.260 |
Meaning that, for example, in computer vision, 01:02:30.620 |
lower layer, bottom layers encode something like edges, 01:02:35.620 |
and higher layers combine the features into something like a cat face. 01:02:39.700 |
a hierarchical representation learning method. 01:02:45.460 |
if decoder layer one attends to encoder final layer, 01:02:50.300 |
which probably has a very different level of information, 01:02:53.060 |
is that some kind of an information bottleneck, 01:02:55.340 |
which actually motivated the original attention mechanism. 01:02:59.540 |
And in practice, I would say, in my experience, 01:03:04.240 |
And that's because my experience was limited to, 01:03:12.740 |
But what if we have 10x or 1000x more layers? 01:03:25.200 |
Final structure we're gonna talk about is the, 01:03:29.060 |
when we do this, there's like a bidirectional thing 01:03:44.700 |
2018, when we were solving that question answering benchmark, SQuAD, 01:03:56.140 |
like I think maybe boosting up the SQuAD score by like 20. 01:04:01.260 |
But at scale, I don't think this matters that much. 01:04:07.060 |
So we did, in Flan 2, we tried both bidirectional 01:04:14.380 |
So, but I wanna point out this bidirectionality, 01:04:23.180 |
So at every turn, the new input has to be encoded again, 01:04:27.240 |
whereas with unidirectional attention, it's much, much better. 01:04:31.140 |
So let's think about this more modern conversation 01:04:34.360 |
between user and assistant: "How are you?", "Bad", and "Why?" 01:04:38.220 |
And so here, if we think about the bidirectional case, 01:04:57.700 |
so we need to do everything from scratch again. 01:05:05.560 |
because now when we are trying to generate why, 01:05:09.780 |
because we cannot attend to the future tokens, 01:05:15.140 |
So if you see the difference, this part can be cached, 01:05:27.300 |
So I would say bidirectional attention did well in 2018, 01:05:33.340 |
and now, because of this engineering challenge, it's less favorable. 01:05:37.140 |
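A rough sketch of mine of the caching point: with causal attention, keys and values for earlier turns never change, so a chat can reuse them. The fake_project helper is a hypothetical stand-in for the real learned projections.

```python
# With causal attention, only the new turn's tokens need to be processed;
# with bidirectional attention, every new turn would change every
# representation and force re-encoding from scratch.
kv_cache = []   # one (key, value) pair per token here; per layer in practice

def fake_project(tok):
    # Hypothetical placeholder so the sketch runs; a real model would use
    # learned key/value projection matrices here.
    return (hash(tok) % 97, hash(tok) % 89)

def encode_new_tokens(new_tokens):
    # Each new token attends to the cached past plus itself; nothing already
    # in the cache has to be recomputed.
    for tok in new_tokens:
        kv_cache.append(fake_project(tok))

encode_new_tokens(["user:", "how", "are", "you", "?"])
encode_new_tokens(["assistant:", "bad", "."])      # reuses the cached turn above
encode_new_tokens(["user:", "why", "?"])           # again only new tokens are added
print(f"{len(kv_cache)} cached token states; with bidirectional attention every "
      "turn would force re-encoding all of them.")
```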
So to conclude, we have looked into this driving force, 01:05:41.580 |
dominant driving force governing this AI research, 01:05:44.340 |
and that was this exponentially cheaper compute 01:05:51.800 |
we analyzed some of the additional structures 01:05:54.180 |
added to the encoder-decoder compared to decoder-only, 01:06:02.120 |
And I wanted to just conclude with this remark. 01:06:08.580 |
which are all, one can say this is just historical artifacts 01:06:12.500 |
and doesn't matter, but if you do many of these, 01:06:17.600 |
You can hopefully think about those in a more unified manner 01:06:21.180 |
and then see, okay, what assumptions in my problem 01:06:25.240 |
that I need to revisit, and are they relevant? 01:06:30.260 |
Can we do it with a more general thing and scale up? 01:06:37.280 |
and together we can really shape the future of AI 01:07:11.560 |
then how long do you think the mixture of experts 01:07:15.640 |
is gonna stay for the new large language models? 01:07:20.640 |
- So one thing I have to apologize for is that the architecture 01:07:24.120 |
is kind of a thing that I'm not really comfortable 01:07:51.100 |
like the parameter sharing and the bidirectional attention, 01:07:54.860 |
can they not be interpreted as less structure, 01:08:23.640 |
if we have enough capacity, we can just handle both. 01:08:34.860 |
oh, actually, maybe I should have repeated the question. 01:08:36.540 |
The question is, can we think about this parameter sharing 01:08:43.700 |
But I think it's a little bit more complicated model, 01:09:01.820 |
- Do you have any thoughts on the recent state-space models 01:09:06.820 |
like Mamba and how that fits into the paradigm 01:09:22.460 |
It's hard to, like, think about it on the spot, 01:09:29.580 |
but I don't, like, architecture is, like, kind of a, 01:09:43.620 |
might become a bottleneck when we think about that. 01:09:50.540 |
So I think it's, transformers have done a good job. 01:09:58.740 |
- So, like, for cross-attention and causal attention, 01:10:05.020 |
it's, like, imposing permutation invariance 01:10:08.540 |
in a way for bidirectional attention instead of causal. 01:10:17.220 |
for invariances for self-supervised learning. 01:10:21.140 |
in terms of complexity that you just talked about? 01:10:26.300 |
versus the, like, the bidirectional attention. 01:10:34.180 |
that being able to attend to the future part of it 01:10:41.660 |
- Also, like, one of the, like, causal attention 01:10:46.340 |
removed the, like, invariance for permutation. 01:10:53.540 |
there's a lot of invariances, right, for augmentation. 01:11:03.180 |
I don't really like this invariances and all these. 01:11:05.860 |
These are, like, how humans think we perceive the vision. 01:11:09.620 |
Like, CNN, for example, is, like, translation invariant, 01:11:09.620 |
And so the machines might be learning the vision 01:11:23.260 |
in a completely different way from how humans do, 01:11:51.620 |
If we do it without the structure, it's actually better. 01:12:07.100 |
So I'm just curious, what are some big inductive biases 01:12:13.700 |
big blocks that we can release, or let go of? 01:12:18.060 |
That would be one question, if you could let me go. 01:12:21.140 |
- The current structure that we should get rid of. 01:12:26.340 |
'cause clearly you've been thinking about this, right? 01:12:38.500 |
- Yeah, so when I think about this as an architecture, 01:12:50.740 |
and at the end, we published this paper called, 01:13:11.340 |
the architecture is not the bottleneck in further scaling. 01:13:17.060 |
especially on the supervised learning paradigm, 01:13:22.260 |
What we're doing with this maximum likelihood estimation is, 01:13:25.220 |
okay, given this, this is the only correct target, 01:13:48.620 |
But now, if you're thinking about very general, 01:13:54.420 |
and then you say this is the only correct answer, 01:13:57.540 |
I think the implication of that could be really severe. 01:14:07.260 |
is one instantiation of not using this maximum likelihood, 01:14:17.860 |
RLHF itself is not really that scalable, I would say. 01:14:23.860 |
this supervised deep learning to train a model 01:14:46.860 |
being the exponentially cheap compute, right? 01:14:53.460 |
and we're going towards performance-oriented architecture. 01:14:56.780 |
So can we rely then on, 'cause in the past 50 years, 01:15:01.480 |
we had transistors doubling or whatever at Moore's Law. 01:15:13.500 |
about how that's gonna project into the future. 01:15:22.780 |
I think what matters is the compute availability. 01:15:28.660 |
and that enabled the continuation of this trend. 01:15:35.460 |
with low-precision thing, which still, I think, is cool. 01:15:39.820 |
But I think there are many other GPU-level things. 01:15:43.400 |
But also, if we are kind of sure about the architecture, 01:15:56.120 |
But GPU, if you think about it, is too general. 01:16:06.300 |
But maybe other things will come as a bottleneck, 01:16:21.320 |
we're talking about exponential driving forces, right? 01:16:23.680 |
You can tell me that you wanna hard-code chips, 01:16:29.800 |
into like paradise or wherever the hell we're going. 01:16:36.280 |
I think we just need to do a little bit better, 01:16:39.200 |
and at some point, the machines will be better than us 01:16:47.080 |
but if we look back at, say, this video two years from now, 01:16:50.440 |
I think it'll be less of a serious thing and more of a joke. 01:16:55.960 |
- All right, so thanks to Hyung Won for an amazing talk.