
Stanford CS25: V4 I Hyung Won Chung of OpenAI


Chapters

0:00 Introduction
2:05 Identifying and understanding the dominant driving force behind AI.
15:18 Overview of Transformer architectures: encoder-decoder, encoder-only and decoder-only
23:29 Differences between encoder-decoder and decoder-only, and rationale for encoder-decoder’s additional structures from the perspective of scaling.

Whisper Transcript

00:00:00.000 | Now, we'll have Hyung-Won give a talk.
00:00:07.640 | So, he's currently a research scientist
00:00:10.320 | on the OpenAI ChatGPT team.
00:00:12.760 | He has worked on various aspects of large language models.
00:00:16.520 | Things like pre-training, instruction fine-tuning,
00:00:19.600 | reinforcement learning from human feedback,
00:00:21.840 | reasoning, and so forth.
00:00:24.280 | Some of his notable works include
00:00:26.160 | the scaling Flan papers such as Flan-T5
00:00:28.840 | and Flan-PaLM, as well as T5X,
00:00:31.520 | the training framework used to train the PaLM language model.
00:00:34.840 | Before OpenAI, he was at Google Brain,
00:00:37.520 | and he received his PhD from MIT.
00:00:40.600 | So, give a hand for Hyung-Won.
00:00:43.600 | >> All right. My name is Hyung-Won,
00:00:56.720 | and really happy to be here today.
00:00:58.720 | This week, I was thinking about,
00:01:00.760 | by the way, is my mic working fine?
00:01:03.480 | Yeah. So, this week, I thought about,
00:01:06.600 | I'm giving a lecture on transformers at Stanford.
00:01:10.240 | What should I talk about? I thought,
00:01:12.600 | okay, some of you in this room and in
00:01:15.480 | Zoom will actually go shape the future of AI.
00:01:18.720 | So, maybe I should talk about that.
00:01:19.920 | It's a really important goal and ambitious,
00:01:21.840 | and we really have to get it right.
00:01:23.440 | So, that could be a good topic to think about.
00:01:26.440 | When we talk about something in the future,
00:01:29.840 | the best place to get advice is to look into history.
00:01:34.080 | In particular, look at the early history of
00:01:37.240 | transformer and try to learn many lessons from there.
00:01:41.120 | The goal would be to develop
00:01:43.440 | a unified perspective in which we can look
00:01:47.240 | into many seemingly disjoint events.
00:01:50.640 | From that, we can probably
00:01:53.280 | hope to project into the future what might be coming.
00:01:56.200 | So, that will be the goal of this lecture,
00:02:01.000 | and we'll look at some of the architectures of the transformers.
00:02:03.680 | So, let's get started. Everyone, I see,
00:02:08.320 | is saying AI is advancing so fast that it's so hard to keep up.
00:02:12.760 | It doesn't matter if you have years of experience;
00:02:15.640 | there are so many things coming out
00:02:17.120 | every week that it's just hard to keep up.
00:02:19.680 | I do see many people spend a lot of time and
00:02:22.520 | energy catching up with the latest developments,
00:02:25.520 | the cutting edge, and the newest thing,
00:02:28.800 | and then not enough attention goes into the older things, because they
00:02:31.840 | become deprecated and no longer relevant.
00:02:36.920 | But I think it's important actually to look into that,
00:02:39.640 | because we really need to,
00:02:41.720 | when things are moving so fast beyond our ability to catch up,
00:02:45.160 | what we need to do is study the change itself,
00:02:47.680 | and that means we can look back at
00:02:49.360 | the previous things and then look at
00:02:52.200 | the current thing and try to map how we got here,
00:02:55.320 | and from which we can look into where we are heading towards.
00:02:59.200 | So, what does it mean to study the change itself?
00:03:04.280 | First, we need to identify
00:03:07.280 | the dominant driving forces behind the change.
00:03:09.800 | So, here dominant is an important word,
00:03:12.200 | because typically a change has many,
00:03:14.880 | many driving forces and we only care about
00:03:17.320 | the dominant one because we're not trying to get really accurate,
00:03:19.840 | you just want to have the sense of directionality.
00:03:22.320 | Second, we need to understand the driving force really well,
00:03:25.400 | and then after that we can predict the future trajectory
00:03:28.400 | by rolling out the driving force and so on.
00:03:31.840 | You heard it right, I mentioned predicting the future.
00:03:34.880 | This is a computer science class,
00:03:36.200 | not astrology or something.
00:03:38.040 | But I think it's actually not that impossible
00:03:41.760 | to predict some future trajectory
00:03:44.520 | of a very narrow scientific domain,
00:03:46.600 | and that endeavor is really useful to do,
00:03:50.560 | because let's say you do all this
00:03:53.800 | and raise your prediction accuracy from one percent to 10 percent,
00:03:57.800 | and then you make 100 predictions;
00:04:00.000 | 10 of them will be correct,
00:04:01.520 | and say one of them will be really, really correct,
00:04:04.080 | meaning it will have an outsized impact that outweighs everything,
00:04:07.600 | and I think that is, from what I've seen,
00:04:11.800 | a very general thing in life:
00:04:13.720 | you really only have to be right a few times.
00:04:16.920 | So, if we think about why predicting the future is difficult,
00:04:24.320 | or maybe even think about the extreme case
00:04:26.280 | where we can all do the prediction with perfect accuracy,
00:04:30.200 | almost perfect accuracy.
00:04:31.440 | So here I'm going to do a very simple experiment
00:04:34.520 | of dropping this pen and follow this same three-step process.
00:04:40.080 | So we're going to identify the dominant driving force.
00:04:43.360 | First of all, what are the driving forces acting on this pen?
00:04:46.200 | Gravity downwards, and is that all?
00:04:48.520 | We also have, say, air friction if I drop it,
00:04:53.520 | and that will cause what's called a drag force acting upwards,
00:04:57.480 | and actually, depending on how I drop this, the orientation,
00:05:01.760 | the aerodynamic interaction will be so complicated
00:05:04.960 | that we don't currently have any analytical way of modeling that.
00:05:08.880 | We can do it with the CFD, the computational fluid dynamics,
00:05:11.760 | but it will be non-trivial.
00:05:12.920 | So we can neglect that.
00:05:15.040 | This is heavy enough that gravity is probably the only dominant force.
00:05:17.840 | So we simplify the problem.
00:05:20.360 | Second, do we understand this dominant driving force, which is gravity?
00:05:24.040 | And we do because we have this Newtonian mechanics
00:05:26.720 | which provides a reasonably good model.
00:05:28.720 | And then with that, we can predict the future trajectory of this pen.
00:05:32.640 | And if you remember from your dynamics class,
00:05:35.840 | if the initial velocity is zero,
00:05:38.640 | I'm not going to give it any velocity,
00:05:40.360 | and let's say the position is zero here,
00:05:42.440 | then x(t) = 1/2 g t^2 will give the precise trajectory of this pen
00:05:48.480 | as I drop it.
00:05:49.760 | So if there is a single driving force that we really understand,
00:05:54.160 | it's actually possible to predict what's going to happen.
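As a concrete aside, here is a minimal sketch in Python of rolling out that single, well-understood driving force; the sample times and g = 9.81 m/s^2 are assumptions added for illustration, not values from the talk.

```python
# Minimal sketch: predicting the trajectory under one dominant force (gravity).
# Assumptions: initial velocity 0, initial position 0, drag neglected, g = 9.81 m/s^2.
g = 9.81  # m/s^2

def distance_fallen(t: float) -> float:
    """x(t) = 1/2 * g * t^2, the trajectory mentioned above."""
    return 0.5 * g * t ** 2

for t in [0.0, 0.1, 0.2, 0.3]:  # sample times, chosen arbitrarily
    print(f"t = {t:.1f} s -> fallen {distance_fallen(t):.3f} m")
```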
00:05:57.640 | So then why do we really fear predicting the future
00:06:02.640 | in the most general sense?
00:06:04.120 | And I argue that, among many reasons,
00:06:06.880 | the sheer number of dominant driving forces
00:06:10.760 | acting in the general case is so large,
00:06:14.760 | and their interaction creates such complexity,
00:06:17.360 | that we cannot predict in the most general sense.
00:06:19.800 | So here's my cartoon way of thinking about the prediction of future.
00:06:23.600 | X-axis, we have a number of dominant driving forces.
00:06:26.240 | Y-axis, we have a prediction difficulty.
00:06:28.440 | So on the left-hand side, we have a dropping a pen.
00:06:30.920 | It's a very simple case.
00:06:32.080 | The difficulty is very small.
00:06:34.000 | You just need to learn physics.
00:06:36.720 | And then as you add more stuff, it just becomes impossible.
00:06:41.400 | So how does this fit into the AI research?
00:06:44.320 | And you might think, "OK, I see all the time things are coming in,
00:06:49.200 | "and we are bombarded by new things,
00:06:51.360 | "and some people will come up with a new agent,
00:06:53.560 | "new modality, new MMLU score, whatever.
00:06:56.200 | "We just see so many things.
00:06:58.040 | "It's just I'm not even able to catch up with the latest thing.
00:07:01.880 | "How can I even hope to predict the future of the AI research?"
00:07:05.520 | But I argue that it's actually simpler
00:07:07.920 | because there is a dominant driving force
00:07:11.240 | that is governing a lot, if not all, of the AI research.
00:07:15.560 | And because of that, I would like to point out
00:07:18.720 | that it's actually closer to the left than to the right
00:07:22.280 | than we actually may perceive.
00:07:25.320 | So what is that driving force?
00:07:28.400 | Oh, maybe before that, I would like to caveat that
00:07:31.960 | when I do this kind of talk,
00:07:34.600 | I would like to not focus too much on the technical stuff,
00:07:37.760 | which you can probably do better in your own time,
00:07:41.160 | but rather I want to share how I think.
00:07:43.920 | And for that, I want to share what my opinion is,
00:07:48.000 | and so it will be very strongly opinionated.
00:07:50.720 | And by no means am I saying this is correct.
00:07:53.920 | I just wanted to share my perspective.
00:07:55.960 | So coming back to this driving force for AI,
00:07:58.320 | what is that dominant driving force?
00:08:00.680 | And here's a plot from Rich Sutton,
00:08:03.400 | and on the y-axis, we have calculations, FLOPs:
00:08:07.760 | if you pay $100, how much computing power do you get?
00:08:12.120 | And it's in log scale.
00:08:13.640 | And then on the x-axis, we have time, spanning more than 100 years.
00:08:18.280 | So this is actually more than exponential,
00:08:21.200 | and I don't know any trend that is as strong
00:08:25.200 | and as long-lasting as this one.
00:08:27.440 | So whenever I see this kind of thing,
00:08:30.800 | I should say, okay, I should not compete with this,
00:08:33.840 | and better, I should try to leverage as much as possible.
00:08:37.920 | And so what this means is you get 10x more compute
00:08:42.920 | every five years if you spend the same amount of dollars.
00:08:46.760 | So, in other words, the cost of compute
00:08:50.960 | is going down exponentially.
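As a rough back-of-the-envelope sketch of that rate (the fixed $100 workload and the horizons below are assumptions for illustration), 10x per five years works out to roughly 1.58x more compute per dollar each year:

```python
# Back-of-the-envelope: 10x more compute per dollar every 5 years (the trend cited above).
annual_factor = 10 ** (1 / 5)   # ~1.58x more compute per dollar each year
print(f"annual growth in compute per dollar: {annual_factor:.2f}x")

# Cost of a fixed workload that costs $100 today, if the trend continues (an assumption).
for years in (5, 10, 20):
    print(f"after {years} years: ~${100 / 10 ** (years / 5):.2f}")
```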
00:08:52.560 | And this associated scaling
00:08:55.600 | is really dominating the AI research,
00:08:58.520 | and that is somewhat hard to take,
00:09:00.560 | but that is, I think, really important to think about.
00:09:04.000 | So coming back to AI research,
00:09:05.960 | how does this exponentially cheaper compute
00:09:09.360 | drive AI research?
00:09:11.560 | Let's think about the job of the AI researchers.
00:09:13.960 | It is to teach machines how to think
00:09:16.040 | in a very general sense.
00:09:17.640 | And one somewhat unfortunate, common approach
00:09:21.640 | is that we teach the machine
00:09:24.880 | how we think we think.
00:09:27.280 | Meaning, we model how we think,
00:09:31.360 | and then try to incorporate that
00:09:32.800 | into some kind of mathematical model, and teach that.
00:09:35.760 | And now the question is, do we understand how we think
00:09:38.880 | at the very low level?
00:09:40.200 | I don't think we do.
00:09:41.600 | I have no idea what's going on.
00:09:43.240 | So it's fundamentally flawed in the sense
00:09:45.280 | that we're trying to model something
00:09:46.320 | that we have no idea about.
00:09:48.240 | And what happens if we go with this kind of approach
00:09:50.880 | is that it imposes a structure
00:09:52.840 | that serves as a shortcut in the short term.
00:09:55.280 | And so you can maybe get a paper or something,
00:09:57.800 | but then it becomes a bottleneck
00:10:00.200 | because we don't know how this will limit
00:10:03.360 | further scaling up.
00:10:05.360 | More fundamentally, what this is doing
00:10:07.280 | is we are limiting the degree of freedom
00:10:10.080 | we are giving to the machines,
00:10:12.000 | and that will backfire at some point.
00:10:13.960 | And this has been going on for decades.
00:10:18.560 | And the Bitter Lesson is, I think, the single most important
00:10:23.080 | piece of writing in AI,
00:10:25.000 | and it says, this is my wording, by the way:
00:10:28.040 | the past 70 years of AI research
00:10:30.640 | can be summarized as developing
00:10:33.480 | progressively more general methods
00:10:35.200 | with weaker modeling assumptions or inductive biases,
00:10:38.680 | and adding more data and compute, in other words, scaling up.
00:10:41.200 | And that has been the recipe of the entire history of AI research,
00:10:45.160 | not fancy things.
00:10:46.720 | And if you think about this,
00:10:48.280 | the models of the 2000s are a lot more difficult
00:10:52.360 | than what we use now,
00:10:54.240 | and so it's much easier to get into AI nowadays
00:10:57.440 | from a technical perspective.
00:10:59.120 | So this is, I think, really the key information:
00:11:04.120 | the cost of compute is going down exponentially,
00:11:07.320 | and it's getting cheaper faster
00:11:09.040 | than we're becoming better researchers.
00:11:11.120 | So don't compete with that
00:11:12.600 | and just try to leverage that as much as possible.
00:11:15.080 | And that is the driving force that I wanted to identify.
00:11:19.840 | And I'm not saying this is the only driving force,
00:11:22.640 | but this is the dominant driving force.
00:11:24.200 | So we can probably neglect the other ones.
00:11:26.480 | So here's a graphical version of that.
00:11:28.360 | X-axis, we have a compute,
00:11:29.760 | Y-axis, we have a performance of some kind.
00:11:31.840 | Let's think about some general intelligence.
00:11:34.160 | And let's look at two different methods.
00:11:36.640 | One with more structure,
00:11:38.440 | more modeling assumptions, fancier math, whatever.
00:11:41.040 | And then the other one is a less structure.
00:11:42.800 | What you see, typically, is that the more structured one
00:11:44.520 | starts with better performance
00:11:47.000 | in the low-compute regime,
00:11:49.080 | but then it plateaus
00:11:50.320 | because of that structure backfiring.
00:11:52.720 | And then with the less structure,
00:11:53.840 | because we give a lot more freedom to the model,
00:11:56.120 | it doesn't work in the beginning.
00:11:57.760 | But then as we add more compute, it starts working.
00:12:00.320 | And then it gets better.
00:12:01.920 | We call this more scalable methods.
00:12:04.840 | So does that mean we should just go with the least structure,
00:12:09.200 | most freedom to the model possible way from the get-go?
00:12:12.760 | And the answer is obviously no.
00:12:14.560 | Let's think about an even-less-structure case.
00:12:16.600 | This red one here will pick up a lot later
00:12:20.680 | and requires a lot more compute.
00:12:23.160 | So it really depends on where we are.
00:12:25.880 | We cannot indefinitely wait for the most general case.
00:12:29.520 | And so let's think about the case
00:12:31.320 | where our compute situation is at this dotted line.
00:12:34.320 | If we're here, we should choose this less-structure one
00:12:37.960 | as opposed to the even-less-structure one,
00:12:40.320 | because the latter doesn't really work yet
00:12:42.040 | and the former does.
00:12:43.600 | But crucially, we need to remember
00:12:45.200 | that we are adding some structure
00:12:47.280 | because we don't have compute.
00:12:48.440 | So we need to remove that later.
00:12:50.640 | And so the difference between these two methods
00:12:53.120 | is the additional inductive biases or structure
00:12:56.080 | that someone imposed,
00:12:58.040 | and those typically don't get removed.
00:13:00.640 | So what that means is that,
00:13:03.800 | at a given level of compute, data,
00:13:06.720 | algorithmic development, and architecture,
00:13:09.800 | there's an optimal inductive bias or structure
00:13:12.680 | that we can add to the problem to make progress.
00:13:15.960 | And that has been really how we have made so much progress.
00:13:19.480 | But these are like shortcuts
00:13:20.720 | that hinder further scaling later on.
00:13:22.760 | So we have to remove them later on
00:13:24.400 | when we have more compute, better algorithm or whatever.
00:13:28.160 | And as a community, we do adding structure very well,
00:13:32.160 | because there's an incentive structure with papers:
00:13:35.400 | you add a nice one, then you get a paper,
00:13:38.000 | but removing structure doesn't really get you much,
00:13:40.720 | so we don't really do that.
00:13:42.640 | And I think we should do a lot more of those.
00:13:45.080 | So maybe another implication of this bitter lesson
00:13:48.160 | is that because of this,
00:13:50.280 | what is better in the long-term
00:13:52.320 | almost necessarily looks worse now.
00:13:55.720 | And this is quite unique to AI research,
00:13:58.240 | because the current paradigm of AI research
00:14:01.920 | is a learning-based method,
00:14:03.520 | meaning that we are giving the models freedom;
00:14:07.000 | the machines choose how they learn.
00:14:09.200 | Because we need to give more freedom,
00:14:11.600 | it's more chaotic at the beginning, so it doesn't work.
00:14:16.480 | But then once it starts working,
00:14:18.360 | we can put in more compute and it gets better.
00:14:21.320 | So it's really important to have this in mind.
00:14:24.400 | So to summarize, we have identified
00:14:27.440 | this dominant driving force behind the AI research.
00:14:30.560 | And that is exponentially cheaper compute
00:14:33.120 | and associated scaling up.
00:14:35.300 | Now that we have identified,
00:14:37.680 | if you remember back from my initial slides,
00:14:40.840 | the next step is to understand this driving force better.
00:14:44.240 | And so that's where we're gonna spend
00:14:46.800 | most of the time doing that.
00:14:48.400 | And for that, we need to go back to some history
00:14:51.960 | of the transformer, because this is a transformers class,
00:14:54.480 | analyze the key structures and decisions
00:14:57.640 | that were made by the researchers at the time,
00:15:00.240 | why they did that,
00:15:01.640 | whether that was an optimal structure
00:15:03.520 | that could have been added at the time,
00:15:05.520 | why it might be irrelevant now,
00:15:08.480 | and whether we should remove it.
00:15:10.000 | And we'll go through some of the practice of this.
00:15:12.120 | And hopefully this will give you some flavor
00:15:14.560 | of what scaling research looks like.
00:15:18.420 | So now we'll go into a little bit of the technical stuff.
00:15:22.100 | Transformer architecture, there are some variants.
00:15:25.040 | I'll talk about three of them.
00:15:27.160 | First is the encoder decoder,
00:15:28.600 | which is the original transformer,
00:15:30.360 | which has a little bit more structure.
00:15:31.880 | The second one is the encoder-only,
00:15:33.440 | which was popularized by BERT.
00:15:36.360 | And then the third one is the decoder-only,
00:15:38.960 | which you can think of as current language models
00:15:42.160 | like GPT-3.
00:15:43.480 | This has a lot less structure than the encoder decoder.
00:15:46.360 | So these are the three types we'll go into detail.
00:15:49.240 | The second, the encoder-only, is actually not that useful
00:15:52.800 | in the most general sense; it still has some place,
00:15:55.460 | but we will just briefly go over it
00:15:58.240 | and then spend most of the time comparing one and three.
00:16:01.440 | So one has more structure,
00:16:03.260 | what's the implication of that and so on.
00:16:06.040 | So first of all, let's think about what a transformer is,
00:16:08.960 | just at a very high level, from first principles.
00:16:13.120 | A transformer is a sequence model,
00:16:15.160 | and a sequence model takes a sequence as input.
00:16:18.720 | The sequence elements can be words or images or whatever.
00:16:23.720 | It's a very general concept.
00:16:25.500 | In this particular example, I'll show you with the words,
00:16:28.160 | sentence is a sequence of words.
00:16:30.040 | And then the first step is to tokenize it,
00:16:32.320 | because we have to represent these words in computers,
00:16:36.880 | which requires some kind of encoding scheme.
00:16:40.720 | So we just do it with a fixed number of integers,
00:16:43.920 | and now we have a sequence of integers.
00:16:46.280 | And then the dominant paradigm nowadays
00:16:49.160 | is to represent each sequence element as a vector,
00:16:52.120 | dense vector, because we know how to multiply them well.
00:16:55.080 | And then so we have a sequence of vectors.
00:16:57.760 | And finally, this sequence model will do the following.
00:17:01.940 | We just want to model the interaction
00:17:04.720 | between sequence elements.
00:17:06.400 | And we do that by let them take
00:17:08.980 | the dot product of each other.
00:17:10.520 | And if the dot product is high,
00:17:12.000 | we can say semantically they are more related
00:17:13.960 | than the dot products that is low.
00:17:16.160 | And that's kind of the sequence model.
00:17:18.080 | And the transformer is a particular type of sequence model
00:17:21.480 | that uses what's called attention
00:17:23.780 | to model this interaction.
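Here is a minimal sketch, in NumPy, of the pipeline just described: tokenize, represent each element as a dense vector, and model interactions with dot products. The toy vocabulary, random embeddings, and dimensions are assumptions for illustration only.

```python
import numpy as np

# Toy "tokenizer": words -> integers (a made-up vocabulary, just for illustration).
vocab = {"that": 0, "is": 1, "good": 2}
tokens = [vocab[w] for w in "that is good".split()]   # sequence of integers

# Represent each token as a dense vector via an embedding table (random here).
rng = np.random.default_rng(0)
d_model = 4
embedding = rng.normal(size=(len(vocab), d_model))
x = embedding[tokens]                                  # shape (seq_len, d_model)

# Model pairwise interaction with dot products; a higher score means more related.
scores = x @ x.T                                       # shape (seq_len, seq_len)

# Attention normalizes the scores into weights that sum to 1 for each token.
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(weights.round(2))
```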
00:17:26.960 | So let's get into the details of this encoder decoder,
00:17:30.600 | which was the original transformer.
00:17:32.120 | It's quite many, many pieces.
00:17:33.500 | So let's go into a little bit, a piece at a time.
00:17:36.660 | So starting with the encoder.
00:17:38.280 | So here I'm going to show you an example
00:17:40.200 | of machine translation, which used to be a very cool thing.
00:17:44.440 | And so you have an English sentence, "That is good,"
00:17:47.840 | and then we're going to translate it into German.
00:17:50.080 | So first thing is to encode this into a dense vector.
00:17:54.000 | So here I'm representing it with this vector
00:17:57.680 | of size three or something.
00:17:59.200 | And then we have to let them take the dot product.
00:18:01.520 | These lines represent which element can talk
00:18:06.160 | to which other elements.
00:18:08.560 | And here, because it's an input,
00:18:10.280 | we take what is called the bidirectional attention.
00:18:12.640 | So any token can talk to any other token.
00:18:14.800 | And then we have this MLP or feed forward layer,
00:18:17.900 | which is per token.
00:18:19.120 | It doesn't have any interaction.
00:18:20.540 | You just do some multiplication just because we can do it.
00:18:25.720 | And then that's one layer, and we repeat that n times.
00:18:29.560 | And that's just the transformer encoder.
00:18:31.880 | And at the end, what you get is the sequence of vectors,
00:18:36.200 | each representing the sequence element,
00:18:38.980 | in this case, a word.
00:18:40.720 | So that's the output of this encoder.
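As a minimal sketch of one encoder layer as just described (NumPy, random weights, with residuals kept but layer norm and multi-head details omitted; all shapes are assumptions for illustration): bidirectional self-attention where any token can attend to any other, followed by a per-token feed-forward layer, stacked N times.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, seq_len, n_layers = 8, 16, 5, 2        # assumed toy sizes
x = rng.normal(size=(seq_len, d_model))               # one dense vector per input token

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def encoder_layer(x, wq, wk, wv, wo, w1, w2):
    # Bidirectional self-attention: no mask, so any token can attend to any other token.
    q, k, v = x @ wq, x @ wk, x @ wv
    x = x + softmax(q @ k.T / np.sqrt(d_model)) @ v @ wo
    # Per-token feed-forward (MLP): applied to each position independently, no interaction.
    return x + np.maximum(x @ w1, 0) @ w2

for _ in range(n_layers):                             # repeat the layer N times
    weights = [rng.normal(size=s) * 0.1 for s in
               [(d_model, d_model)] * 4 + [(d_model, d_ff), (d_ff, d_model)]]
    x = encoder_layer(x, *weights)

print(x.shape)   # still (seq_len, d_model): one output vector per input word
```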
00:18:42.720 | Now let's look at the decoder,
00:18:44.220 | which is similarly shaped stack of layers.
00:18:47.600 | So here we put in as an input what the answer should be.
00:18:53.240 | So here, BOS is the beginning-of-sequence token,
00:18:55.840 | and then "Das ist gut."
00:18:56.940 | I don't know how to pronounce it,
00:18:57.780 | but that's the German translation of "that is good."
00:19:00.200 | And so we kind of go through the similar process.
00:19:03.080 | Here we have a causal self-attention,
00:19:05.160 | meaning that the tokens of time step T
00:19:08.120 | can only attend to T and before,
00:19:10.560 | because when we start generating it,
00:19:12.440 | we don't have the future tokens.
00:19:14.320 | So when we train it, we should limit that as well,
00:19:17.400 | and that is done by masking,
00:19:20.440 | which is just different from the encoder.
00:19:23.780 | So after this, after again N layers,
00:19:28.620 | you get this sequence output;
00:19:31.240 | the output is a sequence.
00:19:33.860 | So it's a sequence-to-sequence mapping;
00:19:35.220 | this is the general encoder-decoder architecture.
00:19:37.940 | And when you get this end of sequence,
00:19:39.960 | you stop generating it.
00:19:41.180 | So this is the overall picture.
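A minimal sketch of that causal masking, with assumed toy shapes: the token at time step t can only attend to positions t and before, so the upper triangle of the score matrix is masked out before the softmax.

```python
import numpy as np

seq_len = 4                                  # e.g. BOS, das, ist, gut (assumed toy length)
scores = np.random.default_rng(0).normal(size=(seq_len, seq_len))

# Causal mask: the token at position t may attend only to positions <= t.
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
scores = np.where(future, -np.inf, scores)

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights.round(2))   # upper triangle is 0: no attention to future tokens
```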
00:19:43.380 | Now I'll point out some important attention patterns.
00:19:46.420 | So we are translating into German
00:19:50.440 | what is input to the encoder.
00:19:52.140 | So there has to be some connection
00:19:53.500 | between the decoder and the encoder.
00:19:55.740 | That is done by this cross-attention mechanism
00:19:57.900 | shown in red,
00:19:59.020 | which just means that each sequence element's vector representation
00:20:02.180 | in the output decoder
00:20:04.580 | should attend to some of those in the encoder.
00:20:06.780 | And that is done.
00:20:07.740 | In particular, an interesting design feature
00:20:10.060 | is that all the layers in the decoder
00:20:13.900 | attend to the final-layer output of the encoder.
00:20:16.980 | I will come back to the implication of this design.
00:20:20.460 | So yep, that's that.
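Here is a minimal sketch of that cross-attention wiring (NumPy, random weights, assumed shapes): queries come from the decoder, keys and values from the encoder, and every decoder layer reads the same final-layer encoder output, which is the design feature revisited later in the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                      # assumed model width
enc_final = rng.normal(size=(3, d))        # final-layer encoder output ("that is good")
dec = rng.normal(size=(4, d))              # decoder states ("BOS das ist gut")

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(dec_x, enc_out, wq, wk, wv):
    # Queries from the decoder; keys and values from the encoder output.
    q, k, v = dec_x @ wq, enc_out @ wk, enc_out @ wv
    return softmax(q @ k.T / np.sqrt(d)) @ v

n_decoder_layers = 3
for _ in range(n_decoder_layers):
    wq, wk, wv = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))
    # Every decoder layer attends to the same, fully encoded final-layer encoder output.
    dec = dec + cross_attention(dec, enc_final, wq, wk, wv)

print(dec.shape)   # (4, 8): one vector per target position
```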
00:20:22.940 | And now move on to the second type of architecture,
00:20:26.100 | which is encoder-only.
00:20:27.060 | We'll spend a little bit of time here.
00:20:28.420 | So again, we have the same input,
00:20:32.820 | and we go through a similar structure.
00:20:35.700 | And then in this case, the final output is a single vector.
00:20:39.480 | Regardless of the length of the sequence,
00:20:41.740 | we just get a single vector.
00:20:43.100 | And that represents the input sequence.
00:20:47.180 | That's the dense vector representation.
00:20:49.500 | And then let's say we do some kind of a sentiment analysis.
00:20:52.540 | We run through a task-specific linear layer
00:20:54.940 | to map it to classification labels,
00:20:57.080 | positive or negative probabilities here.
00:20:59.540 | And that's required for all these task-specific cases.
00:21:04.540 | And this is kind of popularized by BERT.
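A minimal sketch of that task-specific head, with assumed shapes; mean pooling is used here only as a stand-in for how the sequence gets collapsed to a single vector (BERT itself uses the [CLS] token's vector):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, n_labels = 5, 8, 2            # assumed toy sizes
encoded = rng.normal(size=(seq_len, d_model))   # encoder output: one vector per token

# Collapse the whole sequence into a single dense vector (mean pooling here).
pooled = encoded.mean(axis=0)

# Task-specific linear layer mapping that vector to classification labels.
w_cls = rng.normal(size=(d_model, n_labels)) * 0.1
logits = pooled @ w_cls
probs = np.exp(logits) / np.exp(logits).sum()
print(probs)   # e.g. P(negative), P(positive) for sentiment analysis
```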
00:21:06.620 | And what this means is that here at the time,
00:21:10.420 | 2018, when BERT came out,
00:21:12.680 | we had the benchmark called GLUE,
00:21:14.900 | which was a language understanding test.
00:21:16.620 | You have a sequence in,
00:21:17.820 | classification labels out for most cases.
00:21:19.900 | This was how the field really advanced at the time.
00:21:23.020 | So when we care about such tasks,
00:21:25.380 | then there's an incentive
00:21:26.940 | to think about simplifying the problem,
00:21:28.580 | adding the structure to the problem
00:21:29.980 | so that we can make a progress.
00:21:31.160 | So the additional structure
00:21:32.880 | that was put into this particular architecture
00:21:35.140 | is that we're going to give up on generation.
00:21:38.860 | If we do that, it becomes a lot simpler problem.
00:21:41.860 | Instead of sequence to sequence,
00:21:43.460 | we're talking about sequence to classification labels,
00:21:46.100 | and that's just so much easier.
00:21:47.820 | And so at some point, in 2018, 2019,
00:21:51.780 | a lot of the research was,
00:21:54.540 | we sometimes call it being a BERT engineer:
00:21:56.340 | you change something a little bit,
00:21:57.820 | get like 0.5% better on GLUE,
00:22:00.740 | and you get a paper, and things like that.
00:22:02.200 | It was a very chaotic era.
00:22:04.020 | But if you look at it from this perspective,
00:22:08.060 | we are adding the structure
00:22:10.620 | of not generating the sequence,
00:22:12.060 | which gives a lot of performance win,
00:22:15.100 | but in the long term, it's not really useful.
00:22:17.380 | So we're not gonna look
00:22:18.220 | at this encoder only architecture going forward.
00:22:21.340 | Third architecture, decoder only.
00:22:23.280 | This one is personally my favorite,
00:22:25.900 | and it looks kind of daunting
00:22:28.500 | because of this attention pattern,
00:22:30.620 | but it actually is very simple.
00:22:32.700 | So here we only have a single stack,
00:22:35.740 | and it can actually generate stuff.
00:22:37.940 | And so there's a misconception:
00:22:40.700 | some people think that because this decoder-only architecture
00:22:42.940 | is used for language modeling, next-token prediction,
00:22:45.340 | it cannot be used for supervised learning.
00:22:47.220 | But here we can actually do it.
00:22:48.560 | The trick is to have this input, "that is good,"
00:22:51.640 | concatenated with the target.
00:22:53.360 | And if you do that,
00:22:54.260 | then it just becomes a simple sequence-in, sequence-out problem.
00:22:57.940 | So the self-attention mechanism here
00:23:01.180 | is actually handling both the cross-attention
00:23:04.340 | between target and input,
00:23:05.740 | and the self-attention sequence learning within each.
00:23:09.300 | So that's the causal attention.
00:23:10.980 | And then, as I mentioned, the output is a sequence.
00:23:15.140 | And then the key design features are self-attention
00:23:17.820 | serving both roles,
00:23:19.540 | and, in some sense,
00:23:21.700 | sharing the parameters between input and target.
00:23:23.900 | So the same set of parameters is applied
00:23:25.740 | to both the input and the target sequences.
00:23:28.180 | So this is the decoder only.
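A minimal sketch of that trick at the data level, in plain Python; the <sep> and <eos> markers and the loss-masking convention are illustrative assumptions, not details from the talk:

```python
# Turn a supervised (input, target) pair into one sequence for a decoder-only model.
# Token strings are purely illustrative; a real tokenizer produces integer ids.
def to_decoder_only_example(input_text: str, target_text: str):
    sequence = input_text + " <sep> " + target_text + " <eos>"
    # Train with next-token prediction over the whole sequence; a common choice
    # (an assumption here) is to compute the loss only on the target portion.
    loss_mask = [0] * (len(input_text.split()) + 1) + \
                [1] * (len(target_text.split()) + 1)
    return sequence.split(), loss_mask

tokens, mask = to_decoder_only_example("That is good", "Das ist gut")
print(list(zip(tokens, mask)))
```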
00:23:29.860 | Now we will go into the comparison.
00:23:32.460 | So I think there are many,
00:23:35.420 | they look very different, at least on the schematics.
00:23:37.780 | So how different are they actually?
00:23:39.780 | And I argue that they're actually quite similar.
00:23:44.140 | And so to illustrate that,
00:23:45.880 | we're gonna transform starting from this encoder decoder,
00:23:48.580 | which has more structures built in,
00:23:50.420 | and then into the decoder only architecture,
00:23:53.540 | and see what are some of the differences.
00:23:56.100 | And then interpret those differences,
00:23:57.980 | those additional structures,
00:23:59.060 | are they relevant nowadays,
00:24:00.660 | now that we have more compute,
00:24:02.060 | better algorithm, and so on.
00:24:03.460 | So let's have this table.
00:24:06.380 | Four differences, we'll see each of them.
00:24:08.840 | And then as we go through, we'll populate this table.
00:24:12.620 | So let's first look at this additional cross-attention.
00:24:15.940 | What that means is that this, on the left,
00:24:18.380 | is an encoder decoder,
00:24:19.380 | which has this additional red block, the cross-attention,
00:24:22.040 | compared to the simpler one that doesn't have that.
00:24:24.220 | So we wanna make the left closer to the right.
00:24:28.020 | So that means we need to either get rid of it or something.
00:24:31.540 | And an attention mechanism
00:24:33.060 | has four projection matrices.
00:24:35.900 | Self-attention and cross-attention
00:24:38.180 | actually have the same number of parameters, with the same shapes,
00:24:40.300 | so we can just share them.
00:24:41.460 | So that's the first step, share both of these.
00:24:43.620 | And then it becomes mostly the same mechanism.
00:24:46.340 | And then, so that's the first difference,
00:24:48.620 | separate cross-attention,
00:24:49.740 | or self-attention serving both roles.
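To see why that sharing works, here is a minimal shape-level sketch (NumPy, assumed dimensions): self-attention and cross-attention are built from the same four projection matrices with identical shapes, so one set of weights can serve both roles; only where the queries versus the keys and values come from changes.

```python
import numpy as np

d_model = 8                                   # assumed model width
rng = np.random.default_rng(0)
# One shared set of the four attention projections, each of shape (d_model, d_model).
wq, wk, wv, wo = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(queries_from, keys_values_from):
    # Same parameters in both cases; only the sources of q versus k, v differ.
    q = queries_from @ wq
    k, v = keys_values_from @ wk, keys_values_from @ wv
    return softmax(q @ k.T / np.sqrt(d_model)) @ v @ wo

x_dec = rng.normal(size=(4, d_model))   # target-side states
x_enc = rng.normal(size=(3, d_model))   # input-side states

self_attn = attention(x_dec, x_dec)     # self-attention: q, k, v from the same sequence
cross_attn = attention(x_dec, x_enc)    # cross-attention: q from target, k, v from input
print(self_attn.shape, cross_attn.shape)
```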
00:24:52.020 | Second difference is the parameter sharing.
00:24:54.260 | So what that means is that,
00:24:56.540 | between the input and the target,
00:24:58.660 | the encoder-decoder architecture uses separate parameters,
00:25:01.460 | while the decoder-only has a single stack,
00:25:04.140 | so it uses shared parameters.
00:25:05.900 | So if you wanna make the left close to right,
00:25:08.380 | we wanna share the encoder parameters.
00:25:10.860 | So let's do that, just color this.
00:25:12.540 | So now they share the parameters.
00:25:14.740 | Third difference is the target to input attention pattern.
00:25:17.860 | So we need to connect the target to the input,
00:25:20.400 | and how is that done?
00:25:22.460 | In the encoder decoder case, we had this cross-attention,
00:25:25.500 | and then in the decoder only,
00:25:27.420 | it's the self-attention doing everything.
00:25:31.140 | The difference is that we have this,
00:25:35.280 | every layer of the decoder attending
00:25:37.420 | to the final layer output of the encoder.
00:25:39.880 | Whereas if you think about the decoder-only,
00:25:41.780 | it's actually per layer, within layer.
00:25:43.960 | When we are decoding, say, the word "das,"
00:25:47.700 | we are looking at the same-layer representation
00:25:50.620 | of the encoder, and that's within layer,
00:25:53.580 | and I think this is a design feature.
00:25:55.360 | So if we want to make this close to that,
00:25:57.220 | we have to bring back this attention to each layer.
00:26:00.700 | So now layer one will be attending to layer one of this.
00:26:04.760 | And finally, the last difference is the input attention.
00:26:09.700 | I mentioned this bidirectional attention,
00:26:11.860 | and because the decoder-only
00:26:14.580 | typically comes with unidirectional attention,
00:26:17.580 | we need to make them match.
00:26:19.140 | So we can just get rid of it.
00:26:21.540 | I just got rid of some of the arrows.
00:26:24.180 | So then at this point,
00:26:26.540 | these two architectures are almost identical.
00:26:29.820 | A little bit of difference in the cross-attention,
00:26:31.320 | but same number of parameters,
00:26:33.260 | and in deep learning,
00:26:35.180 | if you just train
00:26:36.480 | these two architectures on the same task, with the same data,
00:26:38.780 | I think the results will be pretty much within the noise,
00:26:40.580 | probably closer than if you trained the same thing twice.
00:26:43.300 | So I would say they are identical.
00:26:46.180 | And so these are the main differences.
00:26:48.260 | Now we'll look at what the additional structures are
00:26:51.800 | and what each of them means.
00:26:54.500 | So yeah, that's the populated table now.
00:26:57.480 | And then, so we can say that encoder-decoder,
00:27:00.380 | compared to the decoder-only architecture,
00:27:02.400 | has these additional structures and inductive biases built in.
00:27:07.400 | So let's go into each of them.
00:27:09.340 | The first one:
00:27:11.140 | what the encoder-decoder adds as a structure
00:27:14.300 | is that the input and the target sequences
00:27:16.260 | are sufficiently different
00:27:18.580 | that it's useful to use separate parameters.
00:27:21.140 | That's the assumption.
00:27:22.460 | And so why is that useful?
00:27:24.740 | When can that assumption be useful?
00:27:27.540 | And one example is machine translation.
00:27:30.420 | Back when the transformer was introduced in 2017,
00:27:33.040 | translation was a really popular task,
00:27:35.240 | and it was considered difficult.
00:27:37.240 | And it's just sequence to sequence,
00:27:40.000 | and you can actually have a BLEU score,
00:27:41.640 | which is a heuristic-based method
00:27:43.280 | that can give you a single number,
00:27:44.920 | and then people can optimize that.
00:27:46.720 | So in that task, we have this input and target
00:27:51.060 | in completely different languages.
00:27:52.620 | So if the goal is to learn translation only,
00:27:55.480 | then it kind of makes sense to have,
00:27:57.120 | okay, this parameter in the encoder
00:27:58.760 | will take care of the English,
00:28:00.500 | and this parameter in the decoder
00:28:01.940 | will take care of the German.
00:28:03.220 | That seems natural.
00:28:04.700 | And what about now?
00:28:06.860 | Modern language modeling is just about learning knowledge.
00:28:10.060 | And it's not just about translation,
00:28:11.860 | or not even about language.
00:28:13.080 | Language just comes up as a byproduct
00:28:15.060 | of doing this next token prediction,
00:28:17.460 | and translation as well.
00:28:19.180 | So does it make sense to have a separate parameter
00:28:22.200 | for this kind of situation now?
00:28:24.820 | Like we have some knowledge in German,
00:28:28.100 | some knowledge in English,
00:28:29.720 | and if anything, you wanna combine them,
00:28:31.800 | and if we represent them in a separate parameters,
00:28:34.440 | I don't think that's natural.
00:28:35.920 | So I would say with this much more general,
00:28:38.980 | larger models that can do a lot of things,
00:28:42.400 | this assumption seems very unnatural to me.
00:28:45.540 | Second example is a little bit more modern.
00:28:48.220 | Two years ago, when I was at Google,
00:28:50.420 | with Jason, we did this instruction fine-tuning work,
00:28:54.100 | and what this is, is you take the pre-trained model
00:28:57.460 | and then just fine-tune it on academic datasets
00:29:00.020 | so that it can understand
00:29:02.220 | natural language instructions.
00:29:03.620 | So the detail doesn't matter,
00:29:05.340 | but here, let's think about the performance gain
00:29:09.000 | by doing this fine-tuning
00:29:10.300 | on two different architectures we tried.
00:29:12.700 | So the first five are the Flan-T5 models,
00:29:15.180 | which are T5-based, an encoder-decoder architecture.
00:29:18.300 | The latter ones are Flan-PaLM,
00:29:20.660 | a decoder-only architecture based on PaLM.
00:29:22.780 | So we spent 99% of the time on PaLM,
00:29:27.020 | optimizing a lot of these things.
00:29:28.780 | And then at the end, we just spent like three days on T5,
00:29:31.540 | but the performance gain was a lot higher for it.
00:29:34.660 | And I was really confused about this,
00:29:36.620 | and in a very good way.
00:29:38.120 | And after the paper was published,
00:29:39.740 | I wanted to dig a little bit deeper
00:29:41.860 | into why this might be the case.
00:29:43.340 | So my hypothesis is that it's about the length.
00:29:48.340 | The academic datasets we used, like 1,832 tasks,
00:29:52.700 | have this very distinctive characteristic
00:29:56.700 | where we have a long input,
00:29:58.620 | long in order to make the task more difficult,
00:30:01.060 | but then we cannot make the target long,
00:30:02.740 | because if we do, there's no way to grade it.
00:30:05.980 | So there's a fundamental challenge there.
00:30:07.900 | So what happens is you have a long text of input
00:30:10.620 | and then short text of the target.
00:30:12.380 | And so this is kind of the length distribution
00:30:15.460 | of what went into the Flan fine-tuning.
00:30:18.340 | So then you see this,
00:30:21.140 | you have a very different sequence
00:30:23.540 | going into the encoder as an input,
00:30:25.620 | and a very different type of sequence going into the target.
00:30:28.180 | So now this encoder-decoder architecture
00:30:31.060 | has an assumption that they will be very different.
00:30:33.340 | That structure really shines because of this.
00:30:36.700 | It was a kind of an accident,
00:30:38.260 | but that was, I think, why this really architecture
00:30:41.780 | was just suitable for fine-tuning with the academic datasets.
00:30:46.780 | What about now?
00:30:47.860 | Do we care about this kind of assumption?
00:30:50.660 | And if you think about the general use cases
00:30:53.100 | of language models nowadays,
00:30:55.140 | if anything, the more interesting cases
00:30:57.300 | involve longer generation, longer target.
00:31:00.860 | Just because we cannot grade them
00:31:03.300 | doesn't mean that we are not interested in them.
00:31:05.500 | Actually, if anything, we are more interested in that.
00:31:07.500 | So now we have this longer target situation.
00:31:10.500 | So having separate parameters based on this length difference
00:31:13.500 | doesn't seem to make much sense.
00:31:15.380 | And moreover, if we think about chat applications
00:31:18.780 | like ChatGPT, we do multi-turn conversations,
00:31:21.620 | and what is the target of this turn
00:31:24.500 | becomes the input of the next turn.
00:31:26.540 | And then my question is, does it make sense
00:31:29.180 | to even think about different parameters
00:31:32.580 | if in the next turn it's going to be the same thing?
00:31:34.700 | So that was the first inductive bias we just mentioned.
00:31:40.540 | And then the second structure is that target element
00:31:42.900 | can only attend to the fully encoded ones,
00:31:46.180 | the final output of the encoder.
00:31:47.980 | Let's look at this additional structure, what that means.
00:31:50.580 | So as I mentioned,
00:31:52.100 | we have the decoder attending to this very top layer of the encoder.
00:31:56.020 | And so in deep neural nets,
00:31:58.740 | typically we see that the bottom layers
00:32:00.700 | and the top layers encode information
00:32:02.900 | at a very different level.
00:32:04.580 | Meaning that, for example, in computer vision,
00:32:06.940 | the lower, bottom layers encode something like edges,
00:32:10.340 | and the top layers combine those features into higher-level concepts,
00:32:13.180 | something like a cat face.
00:32:14.620 | And so we call this deep learning
00:32:16.020 | a hierarchical representation learning method.
00:32:19.700 | And so now the question is,
00:32:21.780 | if decoder layer one attends to encoder final layer,
00:32:26.620 | which probably has a very different level of information,
00:32:29.380 | is that some kind of an information bottleneck,
00:32:31.660 | which actually motivated the original attention mechanism?
00:32:35.820 | And in practice, I would say, in my experience,
00:32:39.020 | doesn't really make any difference.
00:32:40.540 | And that's because my experience was limited to,
00:32:43.100 | say, 24 layers of encoder of T5.
00:32:46.420 | So layer one attended to 24, probably fine.
00:32:49.020 | But what if we have 10X or 1,000X more layers?
00:32:52.020 | Would that be problematic?
00:32:53.260 | I'm not really comfortable with that.
00:32:55.860 | So I think this is also unnecessary design
00:32:59.780 | that maybe we need to revisit.
00:33:01.500 | The final structure we're going to talk about is
00:33:05.380 | this bidirectional thing
00:33:08.100 | in the encoder-decoder.
00:33:09.060 | Let's think about that.
00:33:10.540 | So yeah, bidirectional input attention,
00:33:13.360 | is that really necessary?
00:33:14.660 | So when we had BERT,
00:33:18.460 | the B in BERT stands for bidirectional.
00:33:21.020 | In 2018, when we were solving that question-answering benchmark SQuAD,
00:33:24.340 | it was actually a very difficult task.
00:33:26.020 | So if you had any additional trick,
00:33:28.260 | it could make a huge difference.
00:33:29.700 | Bidirectionality was really useful,
00:33:32.700 | I think maybe boosting the SQuAD score by like 20.
00:33:35.980 | So it was a really huge thing.
00:33:37.580 | But at scale, I don't think this matters that much.
00:33:40.540 | This is my highly anecdotal experience.
00:33:43.380 | In Flan 2, we tried both bidirectional
00:33:47.140 | and unidirectional fine-tuning,
00:33:49.140 | and it didn't really make much difference.
00:33:50.700 | But I want to point out that this bidirectionality
00:33:54.540 | actually brings in an engineering challenge
00:33:56.460 | for modern multi-turn chat applications.
00:33:59.500 | At every turn, the new input has to be encoded again,
00:34:03.580 | whereas unidirectional attention is much, much better here.
00:34:05.980 | So here's what I mean by that.
00:34:07.460 | So let's think about this more modern conversation
00:34:10.700 | between user and system.
00:34:12.860 | "How are you?" "Bad." "And why?"
00:34:14.580 | And so here, if we think about the bidirectional case,
00:34:17.340 | when we generate "bad,"
00:34:19.980 | we need to encode this input with the bidirectional thing,
00:34:23.220 | which is fine.
00:34:24.140 | And then after "bad" is generated,
00:34:27.460 | when we're trying to generate "why,"
00:34:29.300 | we'll need to encode "how" again,
00:34:32.060 | because "how" can now attend to "bad."
00:34:34.020 | So we need to do everything from scratch again.
00:34:36.580 | In contrast, if we do the unidirectional one,
00:34:39.700 | we can do much, much better,
00:34:41.900 | because now when we are trying to generate "why,"
00:34:44.780 | we don't have to redo "how,"
00:34:46.100 | because it cannot attend to future tokens,
00:34:49.900 | so we don't have to do anything.
00:34:51.500 | If you see the difference, this part can be cached,
00:34:54.100 | and then this part is the only thing
00:34:55.780 | that has to be encoded again.
00:34:58.500 | So this kind of makes a big difference
00:35:00.140 | when we think about multiple turns going in.
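Here is a minimal sketch of why the unidirectional case caches so well (NumPy, single head, one token at a time; all shapes are assumptions for illustration): with causal attention, earlier tokens' keys and values never change when new tokens arrive, so at each turn only the new tokens are processed and everything before stays cached.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                             # assumed model width
wq, wk, wv = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

k_cache = np.zeros((0, d))                        # keys of everything seen so far
v_cache = np.zeros((0, d))                        # values of everything seen so far

def step(new_token):
    """Process one new token: append its key/value, attend over the whole history."""
    global k_cache, v_cache
    k_cache = np.concatenate([k_cache, new_token @ wk])
    v_cache = np.concatenate([v_cache, new_token @ wv])
    q = new_token @ wq
    return softmax(q @ k_cache.T / np.sqrt(d)) @ v_cache

# Turn 1: the user's tokens ("How are you") are encoded once; their k/v are cached.
for tok in rng.normal(size=(3, 1, d)):
    out = step(tok)
# Turn 2: only the new tokens ("Bad", "and why") are processed; nothing is re-encoded.
for tok in rng.normal(size=(4, 1, d)):
    out = step(tok)
print(k_cache.shape)   # (7, 8): the full history is cached, never redone
```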
00:35:03.660 | So I would say bidirectional attention did well in 2018,
00:35:07.700 | but that is mostly solved by scale now,
00:35:09.660 | and because of this engineering challenge,
00:35:11.860 | we don't really need it.
00:35:13.460 | So to conclude, we have looked into this driving force,
00:35:17.940 | the dominant driving force governing AI research,
00:35:20.660 | and that was exponentially cheaper compute
00:35:23.380 | and the associated scaling effort.
00:35:25.820 | And to understand this driving force,
00:35:28.140 | we analyzed some of the additional structures
00:35:30.500 | added to the encoder-decoder compared to the decoder-only,
00:35:33.540 | and then thought about what that means
00:35:36.620 | from the perspective of scaling.
00:35:38.460 | And I wanted to just conclude with this remark.
00:35:42.460 | So we have looked at these kinds of analyses,
00:35:44.900 | which one can say are just historical artifacts
00:35:48.860 | and don't matter, but if you do many of these
00:35:51.820 | and then look at current events,
00:35:53.940 | you can hopefully think about them in a more unified manner
00:35:57.520 | and then see, okay, what assumptions in my problem
00:36:01.580 | do I need to revisit, and are they relevant?
00:36:04.020 | And if not, why?
00:36:05.140 | And once you have an answer to that:
00:36:06.260 | can we do it with a more general method and scale up?
00:36:09.380 | And so I hope you can go back
00:36:11.620 | and really think about these problems,
00:36:13.620 | and together we can really shape the future of AI
00:36:16.620 | in a really nice way.
00:36:18.180 | So that's it, thanks.
00:36:20.080 | (audience applauding)