Stanford CS224N | 2023 | Lecture 10 - Prompting, Reinforcement Learning from Human Feedback
00:00:00.000 |
Okay, awesome. We're going to get started. So my name is Jesse Mu. I'm a PhD student 00:00:13.400 |
in the CS department here working with the NLP group and really excited to be talking 00:00:18.760 |
about the topic of today's lecture, which is on prompting, instruction fine-tuning, 00:00:23.480 |
and RLHF. So this is all stuff that has been super hot recently because of all the latest 00:00:30.480 |
excitement about chatbots, ChatGPT, etc. And we're going to hopefully get somewhat of an 00:00:36.480 |
understanding as to how these systems are trained. 00:00:40.360 |
Okay, so before that, some course logistics things. So project proposals, both custom 00:00:45.760 |
and final, were due a few minutes ago. So if you haven't done that, this is a nice reminder. 00:00:52.560 |
We're in the process of assigning mentors of projects, so we'll give feedback soon. 00:00:57.840 |
Besides that, assignment five is due on Friday at midnight. We still recommend using Colab 00:01:03.320 |
for the assignments, even if you've had AWS or Azure credits granted. If that doesn't 00:01:07.960 |
work, there's instructions for how to connect to a Kaggle notebook where you will also be 00:01:11.200 |
able to use GPUs. Look for that post on ed. And then finally, also just posted on ed by 00:01:16.760 |
John is a course feedback survey. So this is part of your participation grade. Please fill it out. 00:01:28.440 |
Okay, so let's get into this lecture, which is going to be about what we are trying to 00:01:34.920 |
do with these larger and larger models. Over the years, the compute for these models has 00:01:40.680 |
just gone up by many orders of magnitude, and they're trained on more and more data. So larger and larger 00:01:49.680 |
models, they're seeing more and more data. And in the pre-training lecture, if you recall this slide, 00:01:56.520 |
we talked a little bit about what happens when you do pre-training. And as you begin 00:02:01.680 |
to really learn to predict the missing words in text, right? You learn things 00:02:06.240 |
like syntax, co-reference, sentiment, et cetera. But in this lecture, we're going to take it 00:02:11.840 |
a little bit further and really take this idea to its logical conclusion. So if you 00:02:15.680 |
really follow this idea of we're just going to train a giant language model on all of 00:02:20.400 |
the world's text, you really begin to see language models sort of in a way as rudimentary 00:02:25.680 |
world models. So maybe they're not very good world models, but they kind of have to 00:02:29.480 |
be doing some implicit world modeling just because we have so much information on the 00:02:33.600 |
internet and so much of human collective knowledge is transcribed and written for us on the internet, 00:02:38.680 |
right? So if you are really good at predicting the next word in text, what do you learn to 00:02:42.720 |
do? There's evidence that these large language models are to some degree learning to represent 00:02:47.960 |
and think about agents and humans and the beliefs and actions that they might take. 00:02:52.600 |
So here's an example from our recent paper where we are talking about someone named Pat 00:02:57.800 |
watching a demonstration of a bowling ball and a leaf being dropped at the same time 00:03:01.520 |
in a vacuum chamber. And the idea is here we're saying Pat is a physicist, right? So 00:03:07.320 |
Pat is a physicist and we ask for the language model's next continuation of this sentence. 00:03:13.520 |
Because he's a physicist, we do some inference about what kind of knowledge Pat has and Pat 00:03:17.560 |
will predict that the bowling ball and the leaf will fall at the same time. But if we 00:03:21.880 |
change the sentence of the prompt and we say, well, Pat has actually never seen this demonstration 00:03:25.400 |
before, then Pat will predict that the bowling ball will fall to the ground first, which 00:03:29.600 |
is wrong, right? So if you get really good at predicting the next sentence in text, you 00:03:33.880 |
also to some degree have to learn to predict an agent's beliefs, their backgrounds, common 00:03:39.080 |
knowledge and what they might do next. So not just that, of course, if we continue browsing 00:03:44.920 |
the internet, we see a lot of encyclopedic knowledge. So maybe language models are actually 00:03:48.960 |
good at solving math reasoning problems if they've seen enough demonstrations of math 00:03:53.080 |
on the internet. Code, of course, code generation is a really exciting topic 00:03:58.560 |
that people are looking into, and we'll give a presentation on that in a few weeks. Even 00:04:04.120 |
medicine, right? We're beginning to think about language models trained on medical texts 00:04:07.680 |
and being applied to the sciences and whatnot. So this is what happens when we really take 00:04:11.920 |
this language modeling idea seriously. And this has resulted in a resurgence of interest 00:04:17.760 |
in building language models that are basically assistants, right? You can give them any task 00:04:22.960 |
under the sun. I want to create a three course meal and a language model should be able to 00:04:27.960 |
take a good stab at being able to do this. This is kind of the promise of language modeling. 00:04:34.320 |
But of course, there's a lot of steps required to get from this, from our basic language 00:04:38.720 |
modeling objective. And that's what this lecture is going to be about. So how do we get from 00:04:44.040 |
just predicting the next word in a sentence to something like chat GPT, which you can 00:04:48.600 |
really ask it to do anything and it might fail sometimes, but it's getting really, really 00:04:52.400 |
convincingly good at some things. Okay. So this is the lecture plan. Basically, I'm going 00:04:58.760 |
to talk about as we're working with these large language models, we come up with kind 00:05:02.360 |
of increasingly complex ways of steering the language models closer and closer to something 00:05:06.600 |
like ChatGPT. So we'll start with zero-shot and few-shot learning, then instruction fine 00:05:11.120 |
tuning, and then reinforcement learning from human feedback, or RLHF. Okay. So let's first 00:05:21.000 |
talk about few shot and zero shot learning. And in order to do so, we're again going to 00:05:26.480 |
kind of build off of the pre-training lecture last Tuesday. So in the pre-training lecture, 00:05:30.800 |
John talked about these models like GPT, generative pre-trained transformer, that are these decoder 00:05:38.200 |
only language models. So they're just trained to predict the next word in a corpus of text. 00:05:43.600 |
And back in 2018 was the first iteration of this model. And it was 117 million parameters. 00:05:50.000 |
So at the time it was pretty big. Nowadays, it's definitely much smaller. And again, it's 00:05:54.400 |
just a vanilla transformer decoder using the techniques that you've seen. And it's trained 00:05:58.280 |
on a corpus of books. So about 4.6 gigabytes of text. And what GPT showed was the promise 00:06:04.840 |
at doing this simple language modeling objective and serving as an effective pre-training technique 00:06:10.200 |
for various downstream tasks that you might care about. So if you wanted to apply it to 00:06:13.920 |
something like natural language inference, you would take your premise sentence and your 00:06:17.720 |
hypothesis sentence, concatenate them, and then maybe train a linear classifier on top of the model's final representation. 00:06:24.960 |
OK, but that was three, four, five years ago. What has changed since then? So they came 00:06:34.520 |
out with GPT-2. So GPT-2 was released the next year in 2019. This is 1.5 billion parameters. 00:06:41.720 |
So it's the same architecture as GPT, but just an order of magnitude bigger. And also 00:06:46.620 |
trained on much more data. So we went from 4 gigabytes of books to 40 gigabytes of internet 00:06:53.040 |
text data. So they produced a data set called WebText. This is produced by scraping a bunch 00:06:57.560 |
of links to comments on Reddit. So the idea is that the web contains a lot of spam, maybe 00:07:01.920 |
a lot of low-quality information. But they took links that were posted on Reddit that 00:07:05.680 |
had at least a few upvotes. So humans maybe looked through it and said, you know, this 00:07:09.240 |
is a useful post. So that was kind of a rough proxy of human quality. And that's how they 00:07:14.120 |
collected this large data set. And so if you look at the size of GPT in 2018, we can 00:07:20.920 |
draw a bigger dot, which is the size of GPT-2 in 2019. And one might ask, how much better does this model get? 00:07:30.500 |
So the authors of GPT-2 titled their paper, "Language Models are Unsupervised Multitask 00:07:35.600 |
Learners." And that kind of gives you a hint as to what the key takeaway they found was, 00:07:40.240 |
which is this unsupervised multitasking part. 00:07:44.640 |
So basically, I think the key takeaway from GPT-2 was this idea that language models can 00:07:49.920 |
display zero-shot learning. So what I mean by zero-shot learning is you can do many tasks 00:07:55.720 |
that the model may not have actually explicitly been trained for with no gradient updates. 00:08:00.360 |
So you just kind of query the model by simply specifying the right sequence prediction problem. 00:08:06.340 |
So if you care about question answering, for example, you might include your passage, like 00:08:10.000 |
a Wikipedia article about Tom Brady. And then you'll add a question, so a question, where 00:08:14.000 |
was Tom Brady born? And then include an answer, like A followed by a colon. And then just 00:08:18.720 |
ask the model to predict the next token. You've kind of jury-rigged the model into doing question answering. 00:08:26.720 |
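For concreteness, here is a minimal sketch of the zero-shot QA prompt format being described; the passage text is an illustrative placeholder, not from the lecture slides.

```python
# A sketch of the zero-shot QA prompt format: passage, then Q:, then A:.
passage = "Tom Brady was born on August 3, 1977, in San Mateo, California. ..."
prompt = passage + "\nQ: Where was Tom Brady born?\nA:"
# Feeding `prompt` to the language model and letting it continue should produce
# tokens that read like an answer, e.g. " San Mateo, California".
```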
For other tasks, like classification tasks, another thing you can do is compare different 00:08:30.440 |
probabilities of sequences. So this task is called the Winograd Schema Challenge. It's 00:08:35.080 |
a pronoun resolution task. So the task is to kind of resolve a pronoun which requires 00:08:39.620 |
some world knowledge. So one example is something like, the cat couldn't fit into the hat because 00:08:44.720 |
it was too big. And the question is whether it refers to the cat or to the hat. And in 00:08:50.480 |
this case, it makes most sense for it to refer to the cat, because when something can't fit 00:08:55.400 |
into something else because it's too big, it's the thing being fitted that is too big; you need to use some world knowledge to kind of resolve that. 00:09:00.060 |
So the way that you get zero-shot predictions for this task out of a language model like 00:09:03.560 |
GPT-2 is you just ask the language model, which sequence is more likely? Is the probability 00:09:10.080 |
of the cat couldn't fit into the hat because the cat was too big deemed more likely by 00:09:15.080 |
the language model than the probability that the cat couldn't fit into the hat because 00:09:19.760 |
the hat was too big? You can score those sequences because this is a language model. And from 00:09:24.120 |
there, you get your zero-shot prediction. And you can end up doing fairly well on this task. 00:09:31.960 |
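As a rough sketch of how this scoring can be done in practice, assuming the Hugging Face transformers library and GPT-2 as a stand-in model:

```python
# Compare the total log-probability of the two resolved sentences and
# pick whichever one the language model deems more likely.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

def sequence_log_prob(text: str) -> float:
    ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        # With labels=input_ids, the model returns the mean next-token
        # cross-entropy; multiply by the number of predictions to get
        # the (negative) total log-probability.
        loss = model(ids, labels=ids).loss
    return -loss.item() * (ids.shape[1] - 1)

a = "The cat couldn't fit into the hat because the cat was too big."
b = "The cat couldn't fit into the hat because the hat was too big."
print("cat" if sequence_log_prob(a) > sequence_log_prob(b) else "hat")
```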
OK. Yeah, so digging a little bit more into the results, GPT-2 at the time beat the state 00:09:39.480 |
of the art on a bunch of language modeling benchmarks with no task-specific fine-tuning. 00:09:44.000 |
So no traditional fine-tune on a training set and then test on a testing set. So here's 00:09:48.800 |
an example of such a task. This is a language modeling task called Lambada, where the goal 00:09:52.480 |
is to predict a missing word. And the idea is that the word that you need to predict 00:09:56.600 |
depends on some discourse earlier in the sentence or earlier a few sentences ago. And by simply 00:10:03.560 |
training your language model and then running it on the Lambada task, you end up doing better 00:10:07.880 |
than the supervised fine-tuned state of the art at the time, and similarly across a wide variety of other language modeling benchmarks. 00:10:18.640 |
OK. Another kind of interesting behavior they observed-- and so you'll see hints of things 00:10:27.240 |
that we now take for granted in this paper-- is that you can get interesting zero-shot 00:10:31.200 |
behavior as long as you take some liberties with how you specify the task. So for example, 00:10:36.360 |
let's imagine that we want our model to do summarization. Even though GPT-2 was just 00:10:40.880 |
a language model, how can we make it do summarization? 00:10:44.880 |
The idea they explored was we're going to take an article, some news article, and then 00:10:48.960 |
at the end, we're going to append the TLDR sign, the TLDR token. So this stands for Too 00:10:53.920 |
Long Didn't Read. It's used a lot on Reddit to just say, if you didn't want to read the 00:10:57.460 |
above stuff, here's a few sentences that summarizes it. 00:11:00.920 |
So if you ask the model to predict what follows after the TLDR token, you might expect that 00:11:06.600 |
it'll generate some sort of summary. And this is kind of early whispers at this term that 00:11:12.040 |
we now call prompting, which is thinking of the right way to define a task such that your 00:11:17.160 |
model will do the behavior that you want it to do. 00:11:22.240 |
So if we look at the performance we actually observed on this task, here at the bottom 00:11:25.960 |
is a random baseline. So you just select three sentences from the article. And the scores 00:11:31.040 |
that we're using here are Rouge scores, if you remember the natural language generation 00:11:34.040 |
lecture. GPT-2 is right above. So it's not actually that good. It only does maybe a little 00:11:40.040 |
bit or barely any better than the random baseline. But it is approaching supervised 00:11:46.160 |
approaches that are actually explicitly fine-tuned to do summarization. 00:11:52.120 |
And of course, at the time, it still underperformed the state of the art. But this really showed 00:11:56.180 |
the promise of getting language models to do things that maybe they weren't trained to do. 00:12:01.520 |
OK, so that was GPT-2. That was 2019. Now here's 2020, GPT-3. So GPT-3 is 175 billion 00:12:13.000 |
parameters. So it's another increase in size by an order of magnitude. And at the time, 00:12:17.280 |
it was unprecedented. I think it still is kind of overwhelmingly large for most people. 00:12:22.320 |
And data. So they scaled up the data once again. 00:12:24.720 |
OK, so what does this buy you? This paper's title was "Language Models are Few-Shot Learners." 00:12:30.960 |
So what does that mean? So the key takeaway from GPT-3 was emergent few-shot learning. 00:12:37.480 |
So the idea here is, sure, GPT can still do zero-shot learning. But now you can specify 00:12:43.200 |
a task by basically giving a few examples of the task before asking it to make a prediction for a new example. 00:12:51.000 |
So this is often called in-context learning to stress that there are no gradient updates 00:12:55.540 |
being performed when you learn a new task. You're basically kind of constructing a tiny 00:12:59.580 |
little training data set and just including it in the prompt, including it in the context 00:13:03.520 |
window of your transformer, and then asking it to pick up on what the task is and then 00:13:07.500 |
predict the right answer. And this is in contrast to a separate literature on few-shot learning, 00:13:13.600 |
which assumes that you can do gradient updates. In this case, it's really just a frozen language model. 00:13:21.340 |
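A minimal sketch of such a few-shot prompt; the translation task and exemplars are illustrative (they echo the English-to-French example in the GPT-3 paper), not taken from this slide.

```python
# Few-shot in-context learning: exemplars go in the context window,
# followed by the query; the model is frozen and no gradients are taken.
examples = [("sea otter", "loutre de mer"), ("cheese", "fromage")]
query = "plush giraffe"

prompt = "Translate English to French.\n"
for english, french in examples:
    prompt += f"{english} => {french}\n"
prompt += f"{query} =>"
# Asking the model to continue `prompt` should yield the French translation.
```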
So few-shot learning works, and it's really impressive. So here's a graph. SuperGLUE here 00:13:26.560 |
is a kind of a wide coverage natural language understanding benchmark. And what they did 00:13:30.360 |
was they took GPT-3, and this data point here is what you get when you just do zero-shot 00:13:36.480 |
learning with GPT-3. So you provide an English description of a task to be completed, and ask the model to complete it. 00:13:44.960 |
Just by providing one example, so one shot, you get like a 10% accuracy increase. So you 00:13:49.840 |
give not only the natural language task description, but also an example input and an example output, 00:13:55.520 |
and you ask it to complete the next output. And as you increase to more shots, you do get 00:14:01.000 |
better and better scores, although, of course, you get diminishing returns after a while. 00:14:06.480 |
But what you can notice is that few-shot GPT-3, so no gradient updates, is doing as well as 00:14:12.000 |
or outperforming BERT fine-tuned on the SuperGLUE task explicitly. 00:14:23.680 |
So one thing that I think is really exciting is that you might think, OK, a few-shot learning, 00:14:28.840 |
whatever, it's just memorizing. Maybe there's a lot of examples of needing to do a few-shot 00:14:32.200 |
learning in the internet text data. And that's true, but I think there's also evidence that 00:14:37.320 |
GPT-3 is really learning to do some sort of on-the-fly optimization or reasoning. 00:14:43.560 |
And so the evidence for this comes in the form of these synthetic word unscrambling 00:14:46.840 |
tasks. So the authors came up with a bunch of simple kind of letter manipulation tasks 00:14:51.480 |
that are probably unlikely to exist in internet text data. So these include things like cycling 00:14:56.880 |
through the letters to get the kind of uncycled version of a word, so converting from 'pleap' 00:15:02.160 |
to 'apple', removing random characters that have been inserted into a word, or even just reversing words. 00:15:07.800 |
And what you see here is performance as you do few-shot learning as you increase the model 00:15:12.240 |
size. And what you can see is that the ability to do few-shot learning is kind of an emergent 00:15:19.160 |
property of model scale. So at the very largest model, we're actually seeing a model be able to do these tasks. 00:15:27.160 |
I've noticed the performance on the reversed words is horrible. 00:15:36.160 |
Yeah. Yeah. So the question was about the reversed words, where performance is still low. Yeah, that's an 00:15:42.160 |
example of a task that these models still can't solve yet, although I'm not sure if 00:15:46.600 |
we've evaluated it with newer and newer models. Maybe the latest versions can indeed actually do it. 00:15:51.360 |
Is there some intuition for why this emerges as a result of model scale? 00:15:56.160 |
I think that's a highly active area of research, and there's been papers published every week 00:16:00.360 |
on this. So I think there's a lot of interesting experiments that really try to dissect either 00:16:05.120 |
with synthetic tasks, like can GPT-3 learn linear regression in context? And there's 00:16:10.720 |
some model interpretability tasks, like what in the attention layers or what in the hidden 00:16:14.560 |
states are resulting in this kind of emergent learning. But yeah, I'd have to just refer 00:16:19.160 |
you to the recent literature on that. Anything else? Awesome. 00:16:27.000 |
Okay, so just to summarize, traditional fine tuning here is on the right. We take a bunch 00:16:32.280 |
of examples of a task that we care about. We give it to our model, and then we do a 00:16:35.680 |
gradient step on each example. And then at the end, we hopefully get a model that can 00:16:39.080 |
do well on some outputs. And in this new kind of paradigm of just prompting a language model, 00:16:43.640 |
we just have a frozen language model, and we just give some examples and ask the model to complete the pattern and predict the right output. 00:16:53.320 |
So you might think, and you'd be right, that there are some limits of prompting. Well, 00:16:57.040 |
there's a lot of limits of prompting, but especially for tasks that are too hard. There 00:17:00.680 |
are a lot of tasks that maybe seem too difficult, especially ones that involve maybe richer 00:17:04.840 |
reasoning steps or needing to synthesize multiple pieces of information. And these are tasks 00:17:10.060 |
that humans struggle with too. So one example is GPT-3. I don't have the actual graph here, 00:17:16.620 |
but it was famously bad at doing addition with larger numbers of digits. And so if you prompt 00:17:21.960 |
GPT-3 with a bunch of examples of addition, it won't do it correctly. But part of the reason 00:17:27.000 |
is because humans are also pretty bad at doing this in one step. Like if I asked you to just 00:17:31.280 |
add these two numbers on the fly and didn't give you a pencil and paper, you'd have a hard time too. 00:17:37.520 |
So one observation is that you can just change the prompts and hopefully get some better performance. 00:17:44.100 |
So there's this idea of doing chain of thought prompting, where in standard prompting, we 00:17:49.700 |
give some examples of a task that we'd like to complete. So here is an example of a math 00:17:53.380 |
word problem. And I told you that what we would do is we would give the question and 00:17:58.180 |
then the answer. And then for a data point that we actually care about, we ask the model 00:18:03.320 |
to predict the answer. And the model will try to produce the right answer, and it's often wrong. 00:18:09.780 |
So the idea of chain of thought prompting is to actually demonstrate what kind of reasoning 00:18:14.040 |
you want the model to complete. So in your prompt, you not only put the question, but 00:18:20.060 |
you also put an answer and the kinds of reasoning steps that are required to arrive at the correct 00:18:24.700 |
answer. So here is actually some reasoning of how you actually would answer this tennis 00:18:28.400 |
ball question and then get the right answer. And because the language model is incentivized 00:18:33.540 |
to just follow the pattern and continue the prompt, if you give it another question, it 00:18:38.260 |
will in turn produce a rationale followed by an answer. 00:18:44.860 |
So you're kind of asking the language model to work through the steps itself. And by 00:18:48.640 |
doing so, you end up getting some questions right when you otherwise might not. 00:18:54.620 |
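A sketch of what such a chain-of-thought prompt looks like; the exemplar and test question paraphrase the well-known tennis-ball and cafeteria examples from the Wei et al. chain-of-thought paper.

```python
# Chain-of-thought prompting: the in-context exemplar includes a worked
# rationale before the final answer, so the model imitates that pattern.
exemplar = (
    "Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
    "Each can has 3 tennis balls. How many tennis balls does he have now?\n"
    "A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 balls. "
    "5 + 6 = 11. The answer is 11.\n\n"
)
question = (
    "Q: The cafeteria had 23 apples. If they used 20 to make lunch and "
    "bought 6 more, how many apples do they have?\nA:"
)
prompt = exemplar + question
# The model is expected to continue with a rationale ending in "The answer is 9."
```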
So a super simple idea, but it's shown to be extremely effective. So here is this middle 00:18:59.940 |
school math word problems benchmark. And again, as we scale up the model for GPT and 00:19:04.540 |
some other kinds of models, being able to do chain of thought prompting emerges. So 00:19:10.020 |
we really see a performance approaching that of supervised baselines for these larger and larger models. 00:19:17.220 |
Going back to the problem with the addition of the large numbers, do you have results on 00:19:35.900 |
how chain of thought prompting does for those larger numbers, like on the middle school math word problems? 00:19:37.900 |
Yeah. So the question is, does chain of thought prompting work for those addition problems 00:19:39.020 |
that I had presented? Yeah. There should be some results in the actual paper. They're 00:19:44.220 |
just not here, but you can take a look. Yeah. 00:19:48.540 |
Is there some intuition for how the model is learning without doing gradient updates? 00:19:53.340 |
Intuition about how the model is learning without gradient updates. Yeah. So this is 00:19:56.540 |
related to the question asked earlier about how is this actually happening. That is, yeah, 00:20:02.220 |
again, it's an active area of research. So my understanding of the literature is something 00:20:06.660 |
like you can show that models are kind of almost doing in-context gradient descent as 00:20:11.140 |
they encode a prompt. And you can analyze this with model interpretability experiments. 00:20:16.820 |
But I'm happy to suggest papers afterwards that kind of deal with this problem more carefully. 00:20:26.220 |
Cool. Okay. So a follow up work to this asked the question of, do we actually even need 00:20:35.740 |
examples of reasoning? Do we actually even need to collect humans working through these 00:20:39.860 |
problems? Can we actually just ask the model to reason through things? Just ask it nicely. 00:20:45.980 |
So this introduced this idea called zero shot chain of thought prompting. And it was honestly 00:20:49.820 |
like I think probably the highest impact-to-simplicity ratio I've seen in a paper, 00:20:55.260 |
where it's the simplest possible thing: instead of doing this manual chain of thought 00:20:58.620 |
stuff, you just ask the question, and then for the answer, you first prepend the tokens "let's 00:21:03.740 |
think step by step." And the model will decode as if it had said, let's think step by step. 00:21:09.540 |
And it will work through some reasoning and produce the right answer. So does this work 00:21:15.900 |
on some arithmetic benchmarks? Here's what happens when you prompt the model just zero 00:21:20.300 |
shot. So just asking it to produce the answer right away without any reasoning. Few-shot 00:21:24.900 |
is giving some examples of inputs and outputs. And this is zero shot chain of thought. So 00:21:29.820 |
just asking the model to think through things, you get crazy good accuracy. When we compare 00:21:35.980 |
to actually doing manual chain of thought, you still do better with manual chains of 00:21:39.620 |
thought. But that just goes to show you how simple of an idea this is and ends up producing 00:21:44.620 |
improved performance numbers. So the funny part of this paper was, why use "let's think 00:21:51.620 |
step by step"? They actually used a lot of prompts and tried them out. So here's zero 00:21:55.860 |
shot baseline performance. They tried out a bunch of different prefixes: "The answer is 00:22:00.340 |
after the proof," "Let's think," "Let's think about this logically." And they found that 00:22:04.140 |
let's think step by step was the best one. It turns out this was actually built upon 00:22:09.260 |
later in the year where they actually use a language model to search through the best 00:22:12.860 |
possible strings that would maximize performance on this task, which is probably gross overfitting. 00:22:18.420 |
But the best prompt they found was "Let's work this out in a step by step way 00:22:23.060 |
to be sure we have the right answer." So the "right answer" part is presuming that 00:22:27.060 |
you get the answer right. It's like giving the model some confidence in itself. 00:22:32.780 |
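To make the zero-shot chain-of-thought setup concrete, here is a sketch; the juggler question is the example used in the Kojima et al. paper, and the two-stage answer extraction follows that paper's description.

```python
# Zero-shot chain-of-thought: no exemplars at all, just the trigger phrase.
question = ("Q: A juggler can juggle 16 balls. Half of the balls are golf balls, "
            "and half of the golf balls are blue. How many blue golf balls are there?")
prompt = question + "\nA: Let's think step by step."
# Stage 1: the model decodes a free-form rationale after this prompt.
# Stage 2: a second call appends something like "Therefore, the answer is"
# to that rationale to extract the final answer (here, 4).
```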
So this might seem to you like a total dark arcane art. And that's because it is. We really 00:22:38.260 |
have no intuition as to what's going on here. Or we're trying to build some intuition. But 00:22:44.580 |
as a result, and I'm sure you've seen if you spend time in tech circles or you've seen 00:22:48.380 |
on the internet, there's this whole new idea of prompt engineering being an emergent science 00:22:52.700 |
and profession. So this includes things like asking a model for reasoning. It includes 00:22:57.060 |
jailbreaking language models for telling them to do things that they otherwise aren't trained 00:23:01.780 |
to do. Even AI art like DALI or stable diffusion, this idea of constructing these really complex 00:23:08.740 |
prompts to get model outputs that you want. That's also prompting. Anecdotally, I've heard 00:23:13.660 |
of people saying I'm going to use a code generation model, but I'm going to include the Google 00:23:17.100 |
code header first, because that will produce more professional or bug-free code, depending 00:23:21.460 |
on how much you believe in Google. But yeah, and there's a Wikipedia article on this now 00:23:27.260 |
and there's even startups that are hiring for prompt engineers and they pay quite well. 00:23:30.480 |
So if you want to be a prompt engineer, definitely practice your GPT whispering skills. 00:23:35.980 |
We have a question? Sorry. Yes, you go. Yeah, go ahead. 00:23:44.500 |
A few slides ago, you showed an LM-designed prompt that was like this long. How can you get the LM to come up with that? 00:23:54.620 |
I think they treated it like a reinforcement learning problem. But I'll just direct you 00:23:58.100 |
to this paper at the bottom to learn more details. Yeah, I think it's the Zhou et al. paper. 00:24:02.780 |
So I'm just a bit curious about how they provided feedback. So in case the model was not giving 00:24:03.780 |
the right answer, were there prompts to say that that's not right, maybe think about this 00:24:04.780 |
with a different approach? How is feedback provided? 00:24:09.780 |
They don't incorporate feedback in these kinds of chain of thought prompting experiments. 00:24:25.020 |
If the model gets the answer wrong, then it gets the answer wrong, and we 00:24:27.900 |
just evaluate accuracy. Right. But this idea of incorporating feedback, I think, is quite 00:24:31.700 |
interesting and I think you'll see some maybe hints of discussion of that later on. Yeah. 00:24:41.900 |
Questions? Okay, awesome. Okay, so talking about these three things, I'm going to talk 00:24:51.820 |
about the benefits and limitations of the various different things that we could be 00:24:55.060 |
doing here. So for zero shot and few shot in context learning, the benefit is you don't 00:24:59.860 |
need any fine tuning and you can carefully construct your prompts to hopefully get better 00:25:04.140 |
performance. The downsides are there are limits to what you can fit in context. Transformers 00:25:09.540 |
have a fixed context window of say 1,000 or a few thousand tokens. And I think, as you 00:25:14.940 |
will probably find out, for really complex tasks, you are indeed going to need some gradient 00:25:18.980 |
steps. So you're going to need some sort of fine tuning. But that brings us to the next 00:25:24.220 |
part of the lecture. So that's instruction fine tuning. Okay, so the idea of instruction 00:25:31.180 |
fine tuning is that, sure, these models are pretty good at doing prompting. You can get 00:25:35.900 |
them to do really interesting things. But there is still a problem, which is that language 00:25:40.280 |
models are trained to predict the most likely continuation of tokens. And that is not the 00:25:44.220 |
same as what we want language models to do, which is to assist people. So as an example, 00:25:49.460 |
if I give GPT-3 this kind of prompt, explain the moon landing, GPT-3 is trained to predict, 00:25:54.700 |
you know, if I saw this on the internet somewhere, what is the most likely continuation? Well, 00:25:59.140 |
maybe someone was coming up with a list of things to do with a six year old. So it's 00:26:02.420 |
just predicting a list of other tasks, right? It's not answering your question. And so the 00:26:07.260 |
issue here is that language models are not, to use the term, aligned with user intent. So how 00:26:13.360 |
might we better align models with user intent for this case? Well, super simple answer, 00:26:19.700 |
right? We're machine learners. Let's do machine learning. So we're going to ask a human, give 00:26:25.360 |
me the right answer, right? Give me the way that a language model should respond according 00:26:29.300 |
to this prompt. And let's just do fine tuning. So this is a slide from the pre-training lecture. 00:26:39.140 |
Again, pre-training can improve NLP applications by serving as parameter initialization. So 00:26:45.940 |
this kind of pipeline, I think you are familiar with. And the difference here is that instead 00:26:51.420 |
of fine tuning on a single downstream task of interest, like sentiment analysis, what 00:26:55.460 |
we're going to do is we're going to fine tune on many tasks. So we have a lot of tasks and 00:26:59.860 |
the hope is that we can then generalize to other unseen tasks at test time. So as you 00:27:06.040 |
might expect, data and scale is kind of key for this to work. So we're going to collect 00:27:11.340 |
a bunch of examples of instruction output pairs across many tasks and then fine tune 00:27:16.940 |
our language model and then evaluate generalization to unseen tasks. 00:27:23.100 |
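A minimal sketch of what this looks like in code, assuming the Hugging Face transformers library; the model name and the two instruction-output pairs are placeholders, not the actual FLAN data or recipe.

```python
# Instruction fine-tuning sketch: concatenate (instruction, output) pairs
# drawn from many tasks and train with the ordinary language modeling loss.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

data = [  # hypothetical examples from two different tasks
    ("Translate to French: I like cats.", "J'aime les chats."),
    ("Is the sentiment positive or negative? 'Great movie!'", "Positive."),
]

model.train()
for instruction, output in data:
    text = instruction + "\n" + output + tokenizer.eos_token
    batch = tokenizer(text, return_tensors="pt")
    # Next-token prediction loss over the whole sequence; a real setup would
    # usually mask out the loss on the instruction tokens.
    loss = model(**batch, labels=batch["input_ids"]).loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```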
Yeah, so data and scale is important. So as an example, one recent data set that was published 00:27:29.900 |
for this is called the Supernatural Instructions Dataset. It contains over 1.6 thousand tasks 00:27:35.820 |
containing 3 million examples. So this includes translation, question answering, question 00:27:40.740 |
generation, even coding, mathematical reasoning, etc. And when you look at this, you really 00:27:47.540 |
begin to think, well, is this actually fine tuning or is this just more pre-training? 00:27:51.340 |
And it's actually both. We're kind of blurring the lines here: given the scale that 00:27:55.780 |
we're training this on, it is basically still a general pre-training task, just slightly more 00:28:00.520 |
specific than plain language modeling. 00:28:05.780 |
So one question I have is, now that we are training our model on so many tasks, how do 00:28:10.660 |
we evaluate such a model? Because you can't really say, OK, can you now do sentiment analysis 00:28:15.140 |
well? The scale of tasks we want to evaluate this language model on is much greater. 00:28:22.220 |
So just as a brief aside, a lot of research has been going into building up these benchmarks 00:28:27.740 |
for these massive multitask language models and seeing to what degree they can do not 00:28:32.380 |
only just one task, but just a variety of tasks. So this is the Massive Multitask Language 00:28:37.180 |
Understanding Benchmark or MMLU. It consists of a bunch of benchmarks for measuring language 00:28:42.420 |
model performance on a bunch of knowledge intensive tasks that you would expect a high 00:28:46.380 |
school or college student to complete. So you're testing a language model not only on 00:28:51.500 |
sentiment analysis, but on astronomy and logic and European history. And here are some numbers 00:28:58.340 |
where at the time, GPT-3 is not that good, but it's certainly above a random baseline. 00:29:07.380 |
Here's another example. So this is the Beyond the Imitation Game Benchmark or BigBench. 00:29:11.760 |
This has like a billion authors because it was a huge collaborative effort. And this 00:29:16.100 |
is a word cloud of the tasks that were evaluated. And it really contains some very esoteric 00:29:22.740 |
tasks. So this is an example of one task included where you have to, given a kanji or Japanese 00:29:27.740 |
character in ASCII art, you need to predict the meaning of the character. So we're really testing a very wide range of capabilities here. 00:29:35.060 |
OK, so instruction fine tuning, does it work? Recall the T5 encoder-decoder model. 00:29:44.860 |
So this is kind of Google's encoder-decoder model, where it's pre-trained on this span 00:29:48.540 |
corruption task. So if you don't remember that, you can refer back to that lecture. 00:29:52.780 |
But the authors released a newer version called FLAN-T5, where FLAN stands for Finetuned 00:29:57.780 |
Language Net. And this is T5 models trained on an additional 1.8 thousand tasks, which include 00:30:02.820 |
the natural instructions data that I just mentioned. And if we average across both the 00:30:07.120 |
BigBench and MMLU performance and normalize it, what we see is that instruction fine tuning 00:30:12.940 |
works. And crucially, the bigger the model, the bigger the benefit that you get from doing 00:30:18.300 |
instruction fine tuning. So it's really the large models that stand to benefit from fine-tuning. 00:30:25.640 |
And you might look at this and say, this is kind of sad for academics or anyone without 00:30:29.620 |
a massive GPU cluster. It's like who can run an 11 billion parameter model? I guess the 00:30:34.020 |
one silver lining, if you look at the results here, is the 80 million parameter model, which is the 00:30:38.500 |
smallest one. If you look at after fine tuning, it ends up performing about as well as the 00:30:43.140 |
un-fine tuned 11 billion parameter model. So there's a lot of examples in the literature 00:30:48.020 |
about smaller instruction fine tune pre-trained models outperforming larger models that are 00:30:53.700 |
many, many times the size. So hopefully there's still some hope for people with just smaller compute budgets. 00:31:00.420 |
Any questions? Awesome. In order to really understand the capabilities, I highly recommend 00:31:08.140 |
that you just try it out yourself. So Flan T5 is hosted on Hugging Face. I think Hugging 00:31:13.500 |
Face has a demo where you can just type in a little query, ask it to do anything, see 00:31:17.500 |
what it does. But there are qualitative examples of this working. So for questions where a 00:31:23.140 |
non-instruction fine-tuned model will just kind of waffle on and not answer the question, 00:31:27.820 |
doing instruction fine tuning will get your model to much more accurately reason through and answer them. 00:31:37.380 |
OK. So that was instruction fine tuning. Positives of this method. Super simple, super straightforward. 00:31:45.300 |
It's just doing fine tuning. And you see this really cool ability to generalize to unseen tasks. 00:31:52.940 |
In terms of negatives, does anyone have any ideas for what might be downsides of instruction fine-tuning? 00:31:59.580 |
It seems like it suffers from the same negatives as any human-sourced data. It's hard to get 00:32:13.620 |
people to provide the input, and different people might think different things about what the right output is. 00:32:19.620 |
Yeah, yeah, exactly. So comments are, well, it's hard and annoying to get human labels 00:32:25.700 |
and it's expensive. That's something that definitely matters. And that last part you 00:32:29.180 |
mentioned about there might be, you know, humans might disagree on what the right label 00:32:32.540 |
is. Yeah, that's increasingly a problem. Yeah. So what are the limitations? The obvious limitation 00:32:39.540 |
is money. Collecting ground truth data for so many tasks costs a lot of money. Subtler 00:32:45.740 |
limitations include the one that you were mentioning. So as we begin to ask for more 00:32:50.340 |
creative and open-ended tasks from our models, right, there are tasks where there is no right 00:32:54.300 |
answer. And it's a little bit weird to say, you know, this is an example of how to write 00:32:59.020 |
some story, right? So write me a story about a dog and her pet grasshopper. Like there 00:33:03.100 |
is not one answer to this, but if we were only to collect one or two demonstrations, 00:33:07.740 |
the language modeling objective would say you should put all of your probability mass 00:33:11.820 |
on the two ways that two humans wrote this answer, right? When in reality, there's no single right answer. 00:33:19.900 |
Another problem, which is related kind of fundamentally to language modeling in the first 00:33:23.260 |
place, is that language modeling as an objective penalizes all token level mistakes equally. 00:33:29.860 |
So what I mean by that is if you were asking a language model, for example, to predict 00:33:33.020 |
the sentence, "Avatar is a fantasy TV show," and you were asking it, and let's imagine 00:33:39.140 |
that the LM mispredicted adventure instead of fantasy, right? So adventure is a mistake. 00:33:45.420 |
It's not the right word, but it is equally as bad as if the model were to predict something 00:33:50.500 |
like musical, right? But the problem is that "Avatar is an adventure TV show" is still 00:33:56.420 |
true, right? So it's not necessarily a bad thing, whereas "Avatar is a musical" is just 00:34:00.460 |
false. So under the language modeling objective, right, if the model were equally confident, 00:34:06.020 |
you would pay the equal penalty, an equal loss penalty for predicting either of those 00:34:09.480 |
tokens wrong. But it's clear that this objective is not actually aligned with what users want, 00:34:15.700 |
which is maybe truth or creativity or generally just this idea of human preferences, right? 00:34:21.860 |
Could we do something like multiply the penalty by the distance in the word embedding space 00:34:29.060 |
in order to reduce this? Because musical would have a higher distance away than adventure. 00:34:34.740 |
Yeah, that's an interesting question. It's an interesting idea. I haven't heard of people 00:34:41.140 |
doing that, but it seems plausible. I guess one issue is you might come up with adversarial 00:34:46.540 |
settings where maybe the word embedding distance is also not telling you the right thing, right? 00:34:50.340 |
So for example, show and musical maybe are very close together because they're both shows 00:34:55.220 |
or things to watch, but in terms of veracity, right, they're completely different. One is 00:34:59.860 |
true, one is false, right? So yeah, you can try it, although I think there might be some issues. 00:35:07.540 |
Cool. Okay, so in the next part of the talk, we're going to actually explicitly try to 00:35:14.700 |
satisfy human preferences and come up with a mathematical framework for doing so. And yeah, 00:35:23.020 |
so these are the limitations, as I had just mentioned. So this is where we get into reinforcement learning from human feedback. 00:35:31.020 |
Okay, so RLHF. So let's say we were training a language model on some task like summarization. 00:35:41.700 |
And let's imagine that for each language model sample S, let's imagine that we had a way 00:35:46.340 |
to obtain a human reward of that summary. So we could score this summary with a reward 00:35:52.740 |
function, which we'll call R of S, and the higher the reward, the better. So let's imagine 00:35:59.940 |
we're summarizing this article, and we have this summary, which maybe is pretty good, 00:36:04.900 |
let's say. We had another summary, maybe it's a bit worse. And if we were able to ask a 00:36:10.780 |
human to just rate all these outputs, then the objective that we want to maximize or 00:36:14.740 |
satisfy is very obvious. We just want to maximize the expected reward of samples from our language 00:36:19.900 |
model, right? So in expectation, as we take samples from our language model, P theta, 00:36:25.940 |
we just want to maximize the reward of those samples. Fairly straightforward. 00:36:33.300 |
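Written out, the objective being described is roughly the following, where p_theta is the language model, s-hat a sample from it, and R the (human) reward:

```latex
\max_{\theta}\; \mathbb{E}_{\hat{s} \sim p_{\theta}(s)}\big[\, R(\hat{s}) \,\big]
```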
So for mathematical simplicity here, I'm kind of assuming that there's only one task or 00:36:38.220 |
one prompt, right? So let's imagine we were just trying to summarize this article, but 00:36:42.340 |
we could talk about how to extend it to multiple prompts later on. 00:36:46.580 |
Okay, so this kind of task is the domain of reinforcement learning. So I'm not going to 00:36:52.580 |
presume there's any knowledge of reinforcement learning, although I'm sure some of you are 00:36:55.780 |
quite familiar with it, probably even more familiar than I am. But the field of reinforcement 00:37:00.380 |
learning has studied these kinds of problems, these optimization problems of how to optimize 00:37:04.700 |
something that you can only evaluate by sampling or simulation, for many years now. And in 2013, there was 00:37:11.140 |
a resurgence of interest in reinforcement learning for deep learning specifically. So 00:37:15.100 |
you might have seen these results from DeepMind about an agent learning to play Atari games, 00:37:20.140 |
an agent mastering Go much earlier than expected. 00:37:24.660 |
But interestingly, I think the interest in applying reinforcement learning to modern 00:37:28.140 |
LMs is a bit newer, on the other hand. And I think the kind of earliest success story 00:37:32.940 |
or one of the earliest success stories was only in 2019, for example. So why might this 00:37:37.620 |
be the case? There's a few reasons. I think in general, the field had kind of this sense 00:37:42.100 |
that reinforcement learning with language models was really hard to get right, partially 00:37:46.780 |
because language models are very complicated. And if you think of language models as actors 00:37:52.540 |
that have an action space where they can spit out any sentence, that's a lot of sentences. 00:37:56.980 |
So it's a very complex space to explore. So it still is a really hard problem. So that's 00:38:01.460 |
part of the reason. But also practically, I think there have been these newer algorithms 00:38:06.340 |
that seem to work much better for deep neural models, including language models. And these 00:38:11.080 |
include algorithms like proximal policy optimization. But we won't get into the details of that 00:38:15.620 |
for this course. But these are kind of the reasons why we've become re-interested in applying RL to language models. 00:38:28.060 |
So how do we actually maximize this objective? I've written it down. And ideally, we should 00:38:32.420 |
just change our parameters theta so that reward is high. But it's not really clear how to 00:38:37.020 |
do so. So when we think about it, I mean, what have we learned in the class thus far? 00:38:42.180 |
We know that we can do gradient descent or gradient ascent. So let's try doing gradient 00:38:45.780 |
ascent. We're going to maximize this objective. So we're going to step in the direction of 00:38:49.220 |
steepest gradient. But this quickly becomes a problem, which is what is this quantity 00:38:54.820 |
and how do we evaluate it? How do we estimate this expectation given that the parameters 00:39:00.460 |
we're taking the gradient with respect to, theta, appear in the sampling distribution of the expectation? And 00:39:06.420 |
the second is what if our reward function is not differentiable? Like human judgments 00:39:10.600 |
are not differentiable. We can't backprop through them. And so we need this to work with arbitrary, non-differentiable reward functions. 00:39:16.620 |
So there's a class of methods in reinforcement learning called policy gradient methods that 00:39:22.820 |
gives us tools for estimating and optimizing this objective. And for the purposes of this 00:39:28.220 |
course, I'm going to try to describe the highest level possible intuition for this, which looks 00:39:35.260 |
at the math and shows what's going on here. But it is going to omit a lot of the details. 00:39:40.380 |
And a full treatment of reinforcement learning is definitely outside of the scope of this 00:39:43.820 |
course. So if you're more interested in this kind of content, you should check out CS234 00:39:48.620 |
Reinforcement Learning, for example. And in general, I think this is going to get a little 00:39:53.380 |
mathy, but it's totally fine if you don't understand it. We will talk, we'll regroup 00:39:56.740 |
at the end and just show what this means for how to do RLHF. 00:40:03.320 |
But what I'm going to do is just describe how we actually estimate this objective. So 00:40:06.780 |
we want to obtain this gradient. So it's the gradient of the expectation of the reward 00:40:12.980 |
of samples from our language model. And if we do the math, we break this apart. This 00:40:17.260 |
is our definition of what an expectation is. We're going to sum over all sentences, weighted 00:40:21.580 |
by their probability. And due to the linearity of the gradient, we can push the gradient operator inside the sum. 00:40:31.680 |
Now what we're going to do is we're going to use a very handy trick known as a log derivative 00:40:35.380 |
trick. And this is called a trick, but it's really just the chain rule. But let's just 00:40:38.940 |
see what happens when we take the gradient of the log probability of a sample from our 00:40:46.240 |
So if I take the gradient, then how do we use the chain rule? So the gradient of the 00:40:49.900 |
log of something is going to be 1 over that something times the gradient 00:40:53.460 |
of that something. So 1 over P theta of s times the gradient of P theta of s. And if we rearrange, we 00:40:58.380 |
see that we can alternatively write the gradient of P theta of s as this product. So P theta 00:41:04.580 |
of s times the gradient of the log P theta of s. And we can plug this back in. 00:41:12.660 |
And the reason why we're doing this is because we're going to convert this into a form where 00:41:15.740 |
the expectation is easy to estimate. So we plug it back in. That gives us this. And if 00:41:22.340 |
you squint quite closely at this last equation here, this first part here is the definition 00:41:28.540 |
of an expectation. We are summing over a bunch of samples from our model, and we are weighting 00:41:33.060 |
it by the probability of that sample, which means that we can rewrite it as an expectation. 00:41:37.740 |
And in particular, it's an expectation of this quantity here. So let's just rewrite 00:41:42.900 |
it. And this gives us our kind of newer form of this objective. So these two are equivalent. 00:41:50.500 |
And what has happened here is we've kind of shoved the gradient inside of the expectation, 00:41:54.860 |
if that makes sense. So why is this useful? Does anyone have any questions on this before 00:42:00.700 |
I move on? If you don't understand it, that's fine as well, because we will go over the intuition in a second. 00:42:14.420 |
So we've converted this into this. And we put the gradient inside the expectation, which 00:42:20.180 |
means we can now approximate this objective with Monte Carlo samples. So the way to approximate 00:42:24.900 |
any expectation is to just take a bunch of samples and then average them. So approximately, 00:42:30.420 |
this is equal to sampling a finite number of samples from our model, and then summing 00:42:34.920 |
up and averaging the reward times the gradient of the log probability 00:42:39.700 |
of that sample. And that gives us this update rule, plugging it back in for the gradient ascent step. 00:42:49.580 |
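Reconstructing the spoken derivation in symbols (m is the number of Monte Carlo samples, alpha the learning rate), it goes roughly like this:

```latex
\nabla_{\theta}\,\mathbb{E}_{\hat{s}\sim p_{\theta}(s)}[R(\hat{s})]
 = \sum_{s} R(s)\,\nabla_{\theta}\, p_{\theta}(s)
 = \sum_{s} R(s)\, p_{\theta}(s)\,\nabla_{\theta}\log p_{\theta}(s)
 = \mathbb{E}_{\hat{s}\sim p_{\theta}(s)}\big[R(\hat{s})\,\nabla_{\theta}\log p_{\theta}(\hat{s})\big]
 \approx \frac{1}{m}\sum_{i=1}^{m} R(s_i)\,\nabla_{\theta}\log p_{\theta}(s_i),
 \qquad s_i \sim p_{\theta}(s)

% which yields the gradient ascent update:
\theta_{t+1} \;\leftarrow\; \theta_t \;+\; \alpha\,\frac{1}{m}\sum_{i=1}^{m} R(s_i)\,\nabla_{\theta}\log p_{\theta}(s_i)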
So what is this? What does this mean? Let's think about a very simple case. Imagine the 00:42:56.840 |
reward was a binary reward. So it was either 0 or 1. So for example, imagine we were trying 00:43:01.820 |
to train a language model to talk about cats. So whenever it utters a sentence with the 00:43:05.500 |
word cat, we give it a 1 reward. Otherwise, we give it a 0 reward. Now, if our reward 00:43:10.860 |
is binary, does anyone know what this objective reduces to or look like? Any ideas? If I've 00:43:27.460 |
The reward would just be an indicator function. 00:43:35.380 |
So basically, to answer, the reward would be 0 everywhere, except for sentences that 00:43:41.260 |
contain the word cat. And in that case, it would be 1. So basically, that would just 00:43:46.300 |
look like vanilla gradient descent, just on sentences that contain the word cat. 00:43:52.760 |
So to generalize this to the more general case, where the reward is scalar, what this 00:43:57.860 |
is looking like, if you look at it, is if r is very high, very positive, then we're 00:44:03.060 |
multiplying the gradient of that sample by a large number. And so our objective will 00:44:07.420 |
try to take gradient steps in the direction of maximizing the probability of producing 00:44:11.460 |
that sample again, producing the sample that led to high reward. 00:44:15.680 |
And on the other hand, if r is low or even negative, then we will actively take steps 00:44:19.700 |
to minimize the probability of that happening again. And that's the English intuition of 00:44:24.180 |
what's going on here. The reason why we call it reinforcement learning is because we want 00:44:28.280 |
to reinforce good actions and increase the probability that they happen again in the 00:44:33.100 |
And hopefully, this intuitively makes sense to all of you. Let's say you're playing a 00:44:36.100 |
video game, and on one run, you get a super high score. And you think to yourself, oh, 00:44:40.300 |
that was really good. Whatever I did that time, I should do again in the future. This 00:44:43.980 |
is what we're trying to capture with this kind of update. 00:44:47.020 |
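A toy sketch of this update, using a single-token "policy" over a tiny vocabulary rather than a real language model; the reward is 1 only when the sampled token is "cat", mirroring the example above.

```python
# REINFORCE-style update: sample, score, and reweight log-probs by reward.
import torch

vocab = ["dog", "cat", "fish", "bird"]
logits = torch.zeros(len(vocab), requires_grad=True)  # toy policy parameters theta
optimizer = torch.optim.SGD([logits], lr=0.1)

def reward(token_id: int) -> float:
    return 1.0 if vocab[token_id] == "cat" else 0.0

for step in range(100):
    dist = torch.distributions.Categorical(logits=logits)
    samples = dist.sample((16,))                       # m Monte Carlo samples s_i ~ p_theta
    rewards = torch.tensor([reward(int(s)) for s in samples])
    log_probs = dist.log_prob(samples)                 # log p_theta(s_i)
    loss = -(rewards * log_probs).mean()               # negative of the estimated objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(torch.softmax(logits, dim=-1))  # probability mass should concentrate on "cat"
```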
Is there any reason that we use policy gradient and not value iteration or other methods? 00:44:55.300 |
You can do a lot of things. I think there have been methods for doing Q-learning, offline 00:44:59.980 |
learning, et cetera, with language models. I think the design space has been very underexplored. 00:45:06.660 |
So there's a lot of low-hanging fruit out there for people who are willing to think 00:45:09.440 |
about what fancy things we can do in RL and apply them to this language modeling case. 00:45:15.540 |
And in practice, what we use is not this simple thing, but a fancier algorithm like PPO. 00:45:22.340 |
Do you know how this works when, for language models, the action space is super big, like almost infinite? 00:45:28.100 |
So that's the challenge. So one thing that I haven't mentioned here is that right now, 00:45:33.500 |
I'm talking about entire samples of sentences, which is a massive space. In practice, when 00:45:38.260 |
we do RL, we actually do it at the level of generating individual tokens. So each token 00:45:42.180 |
is, let's say, one of GPT's 50,000 tokens. So it's a pretty large action space, but it's still tractable. 00:45:51.300 |
So that kind of answers this question I was asking, which is, can you see any problems 00:45:54.240 |
with this objective? Which is that this is a very simplified objective. There is a lot 00:45:58.440 |
more tricks needed to make this work. But hopefully, this has given you kind of the 00:46:02.100 |
high-level intuition as to what we're trying to do in the first place. 00:46:09.260 |
OK, so now we are set. We have a bunch of samples from a language model. And for any 00:46:19.020 |
arbitrary reward function, like we're just asking a human to rate these samples, we can 00:46:26.260 |
OK, so not so fast. There's a few problems. The first is the same as in the instruction 00:46:31.660 |
fine-tuning case, which is that keeping a human in the loop is expensive. I don't really 00:46:36.180 |
want to supervise every single output from a language model. I don't know if you all would want to do that either. 00:46:44.660 |
So one idea is, instead of needing to ask humans for preferences every single time, 00:46:49.120 |
you can actually build a model of their preferences, like literally just train an NLP model of 00:46:53.540 |
their preferences. So this idea was kind of first introduced outside of language modeling 00:46:58.820 |
by this paper, Knox and Stone. They called it TAMER. But we're going to see it re-implemented 00:47:04.560 |
in this idea, where we're going to train a language model-- we'll call it a reward model, 00:47:08.940 |
RM, which is parameterized by phi-- to predict human preferences from an annotated data set. 00:47:15.060 |
And then when doing RLHF, we're going to optimize for the reward model's rewards instead of actual human rewards. 00:47:25.140 |
Here's another conceptual problem. So here's a new sample for our summarization task. What 00:47:30.260 |
is the score of the sample? Anyone give me a number. Does anyone want to rate this sample? 00:47:36.020 |
Is it a 3? A 6? What scale are we using? Et cetera. 00:47:42.560 |
So the issue here is that human judgments can be noisy and miscalibrated when you ask 00:47:46.660 |
people for things alone. So one workaround for this problem is, instead of asking for 00:47:53.960 |
direct ratings, ask humans to compare two summaries and judge which one is better. This 00:47:59.660 |
has been shown, I think, in a variety of fields where people work with human subjects and 00:48:03.400 |
human responses to be more reliable. This includes psychology and medicine, et cetera. 00:48:09.340 |
So in other words, instead of asking humans to just give absolute scores, we're going 00:48:13.400 |
to ask humans to compare different samples and rate which one is better. So as an example, 00:48:19.580 |
maybe this first sample is better than the middle sample, and it's better than the last one. 00:48:25.760 |
Now that we have these pairwise comparisons, our reward model is going to generate latent 00:48:30.260 |
scores, so implicit scores based on this pairwise comparison data. So our reward model is a 00:48:35.660 |
language model that takes in a possible sample, and then it's going to produce a number, which is its predicted reward. 00:48:43.700 |
And the way that we're going to train this model-- and again, you don't really need to 00:48:46.980 |
know too much of the details here, but this is a classic statistical comparison model-- 00:48:51.700 |
is via the following loss, where the reward model essentially should just predict a higher 00:48:56.700 |
score if a sample is judged to be better than another sample. So in expectation, if we sample 00:49:03.160 |
winning samples and losing samples from our data sets, then if you look at this term here, 00:49:09.160 |
the score of the winning sample should be higher than the score of the losing sample. Does 00:49:16.540 |
that make sense? And in doing so, by just training on this objective, you will get a 00:49:22.020 |
language model that will learn to assign numerical scores to things, which indicate their relative 00:49:27.420 |
preference over other samples. And we can use those outputs as rewards. 00:49:32.580 |
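The pairwise loss being described is, roughly, the standard comparison loss used in InstructGPT-style reward models, where s^w is the sample judged better, s^l the one judged worse, and sigma is the logistic sigmoid:

```latex
J_{RM}(\phi) \;=\; -\,\mathbb{E}_{(s^{w},\,s^{l})\sim D}
  \Big[\, \log \sigma\big( RM_{\phi}(s^{w}) - RM_{\phi}(s^{l}) \big) \,\Big]
```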
Is there some renormalization either in the output or somewhere else? 00:49:44.140 |
Yeah, so I don't remember if it happens during training. But certainly, after you've trained 00:49:49.580 |
this model, you normalize the reward model so that the score is-- the expectation of 00:49:52.380 |
the score is 0, because that's good for reinforcement learning and things like that as well. Yeah, 00:49:59.380 |
How do we account for the fact that these judgments are noisy, and some people could view S3 00:50:06.780 |
as better than S1? How do we account for that noise in the ordering? 00:50:14.260 |
Yeah, I think that's just kind of limitations with asking for these preferences in the first 00:50:20.060 |
place is that humans will disagree. So we really have no ground truth unless we maybe 00:50:24.500 |
ask an ensemble of humans, for example. That's just a limitation with this. I think hopefully, 00:50:29.540 |
in the limit with enough data, this kind of noise washes out. But it's certainly an issue. 00:50:33.980 |
And this next slide will also kind of touch on this. 00:50:38.180 |
So does the reward model work? Can we actually learn to model human preferences in this way? 00:50:42.620 |
This is obviously an important sanity check before we actually try to optimize this 00:50:45.620 |
objective. And they measured this. So this is kind of evaluating the reward model on 00:50:51.540 |
a standard kind of validation set. So can the reward model predict outcomes for data 00:50:56.660 |
points that it has not seen during training? And does it change based on model size or 00:51:02.020 |
amount of data? And if you notice here, there's one dashed line, which is the human baseline, 00:51:06.780 |
which is if you ask a human to predict the outcome, a human does not get 100% accuracy 00:51:11.660 |
because humans disagree. And even an ensemble of, let's say, five humans also doesn't get 00:51:16.660 |
100% accuracy because humans have different preferences. 00:51:20.580 |
But the key takeaway here is that for the largest possible model and for enough data, 00:51:25.820 |
a reward model, at least on the validation set that they used, is kind of approaching 00:51:30.380 |
the performance of a single human. And that's kind of a green light that maybe we can use it as a stand-in for human feedback. 00:51:42.940 |
So if there are no questions, these are the components of RLHF. So we have a 00:51:49.500 |
pre-trained model, maybe it's instruction fine-tuned, which we're going to call P of 00:51:53.180 |
PT. We have a reward model, which produces scalar rewards for language model outputs, 00:51:59.500 |
and it is trained on a dataset of human comparisons. And we have a method, policy gradient, for 00:52:04.900 |
arbitrarily optimizing language model parameters towards some reward function. 00:52:10.220 |
And so now if you want to do RLHF, you clone the pre-trained model-- we're going to 00:52:14.860 |
call this copy the RL model, with parameters theta that we're 00:52:19.420 |
actually going to optimize. And we're going to optimize the following reward with reinforcement 00:52:25.660 |
learning. And this reward looks a little bit more complicated than just using the reward 00:52:29.660 |
model. And the extra term here is a penalty, which prevents us from diverging too far from 00:52:37.060 |
the pre-trained model. So in expectation, this is known as the KL or Kullback-Leibler 00:52:41.860 |
divergence between the RL model and the pre-trained model. 00:52:47.940 |
And I'll explain why we need this in a few slides. But basically, if you over-optimize 00:52:52.580 |
the reward model, you end up producing-- you can produce gibberish. And what happens is 00:52:57.300 |
you pay a price. So this quantity is large if the probability of a sample under the RL 00:53:04.180 |
tuned model is much higher than the probability of the sample under the pre-trained model. 00:53:08.820 |
So the pre-trained model would say, this is a very unlikely sequence of characters for 00:53:12.160 |
anyone to say. That's when you would pay a price here. And beta here is a tunable parameter. 00:53:18.780 |
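Here's a small sketch of that combined objective, computed for one sampled output. The function and variable names are placeholders, and real implementations differ in details (for example, computing a per-token KL term), so read this as the shape of the idea rather than the actual recipe:

```python
def rlhf_reward(rm_score, logp_rl, logp_pt, beta=0.02):
    """Reward used during the RL step for one sampled output s:

        R(s) = RM_phi(s) - beta * ( log p_RL(s) - log p_PT(s) )

    rm_score: scalar score from the (normalized) reward model
    logp_rl:  log-probability of s under the model being fine-tuned
    logp_pt:  log-probability of s under the frozen pre-trained copy
    beta:     tunable strength of the penalty
    """
    # The penalty is large when the RL-tuned model puts much more mass on s
    # than the pre-trained model does; in expectation over samples from p_RL
    # this term is the KL divergence KL(p_RL || p_PT).
    return rm_score - beta * (logp_rl - logp_pt)
```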
When you say initialize a copy, that means the first iteration, PRL is equal to PPT? 00:53:28.060 |
That's right. Yeah. Yeah, when I say initialize a copy, basically, we want to be able to compare 00:53:33.200 |
to the non-fine-tuned model just to evaluate this penalty term. So we just keep the pre-trained model's predictions around, frozen. 00:53:43.620 |
More questions? Great. So does it work? The answer is yes. So here is the key takeaway, 00:53:56.660 |
at least for the summarization task on this Daily Mail dataset. So again, we're looking 00:54:02.220 |
at different model sizes. But at the end here, we see that if we do just pre-training-- so 00:54:07.060 |
just like the typical language modeling objective that GPT uses-- you end up producing summaries 00:54:12.060 |
that, in general, are not preferred to the reference summaries. The y-axis 00:54:16.020 |
here is the fraction of times that a human prefers the model-generated summary to a summary that 00:54:21.860 |
a human actually wrote, the one that's in the dataset. 00:54:25.480 |
So pre-training doesn't work well, even if you do supervised learning. So supervised 00:54:29.340 |
learning in this case is, let's actually fine-tune our model on the summaries that were in our 00:54:33.900 |
data sets. Even if you do that, you still kind of underperform the reference summaries, 00:54:39.020 |
because you're not perfectly modeling those summaries. But it's only with this human feedback 00:54:44.860 |
that we end up producing a language model that actually ends up producing summaries 00:54:48.780 |
that are judged to be better than the summaries in a data set that you were training on in 00:54:52.260 |
the first place. I think that's quite interesting. Any questions? 00:55:06.100 |
So now we talk about-- yeah, we're getting closer and closer to something like InstructGPT 00:55:10.900 |
or ChatGPT. The basic idea of InstructGPT is that we are scaling up RLHF to not just 00:55:18.660 |
one prompt, as I had described previously, but tens of thousands of prompts. And if you 00:55:23.940 |
look at these three pieces, these are the three pieces that we've just described. The 00:55:27.740 |
first piece here being instruction fine-tuning, the second piece being RLHF, and the third 00:55:33.540 |
piece-- oh, sorry, the second part being reward model training, and the last part being RLHF. 00:55:39.740 |
The difference here is that they use 30,000 tasks. So again, with the same instruction 00:55:47.060 |
fine-tuning idea, it's really about the scale and diversity of tasks that really matters 00:55:50.820 |
for getting good performance for these things. Yeah? 00:55:54.300 |
Yeah, so the preceding results, you suggested that you really needed the RLHF, and it didn't 00:56:08.540 |
work so well to do supervised learning on the data. But they do supervised learning 00:56:14.020 |
on the data in the first fine-tuning stage. Is that necessary, or could they 00:56:21.060 |
have skipped that and gone straight to RLHF? 00:56:27.140 |
Oh, yeah, that's a good question. So I think a key point here is that they initialized 00:56:31.300 |
the RL policy on the supervised policy. So they first got the model getting reasonably 00:56:36.340 |
good at doing summarization first, and then you do the RLHF on top to get the boost in performance. 00:56:42.640 |
Your question you're asking is maybe, can we just do the RLHF starting from that pre-trained 00:56:46.380 |
baseline? That's a good question. I don't think they explored that, although I'm not 00:56:52.300 |
sure. I'd have to look at the paper again to remind myself. Yeah. 00:57:02.380 |
So certainly for something like InstructGPT, yeah, they've always kind of presumed that 00:57:06.240 |
you need the kind of fine-tuning phase first, and then you build on top of it. But I think, 00:57:10.820 |
yeah, there's still some interesting open questions as to whether you can just go directly from the pre-trained model. 00:57:16.580 |
Is the human reward function trained simultaneously with the fine-tuning of the language model? 00:57:30.340 |
The reward model should be trained first. Yeah. You train it first, you make sure it's good, and then you do the RL against it. 00:57:35.620 |
What are the samples for the human rewards? Do they come from text generated by the 00:57:42.620 |
language model? Or where do the training samples come from? 00:57:48.220 |
So, yeah, actually, it's a good question. Where do the rewards come from? So there's 00:57:54.700 |
kind of an iterative process you can apply where you kind of repeat steps two and three 00:57:58.580 |
over and over again. So you sample a bunch of outputs from your language model. You get 00:58:03.380 |
humans to rate them. You then do RLHF to update your model again. And then you sample more outputs and repeat. 00:58:09.800 |
So in general, the rewards are done on sampled model outputs, because those are the outputs 00:58:13.780 |
that you want to steer in one direction or another. But you can do this in an iterative 00:58:17.820 |
process where you kind of do RL and then maybe train a better reward model based on the new 00:58:22.420 |
outputs and continue. And I think they do a few iterations in InstructGPT, for example. 00:58:31.220 |
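To summarize that iterative loop as pseudocode (purely a sketch; every helper name here is hypothetical, and this is not claiming to be the InstructGPT pipeline):

```python
def iterative_rlhf(policy, reward_model, prompts, num_rounds=3):
    """Alternate between collecting comparisons on fresh samples and
    doing RL against the refreshed reward model."""
    for _ in range(num_rounds):
        # 1. Sample outputs from the current policy.
        samples = [policy.generate(p) for p in prompts]
        # 2. Humans compare pairs of samples (hypothetical labeling step).
        comparisons = collect_human_comparisons(prompts, samples)
        # 3. Refit the reward model on the accumulated comparison data.
        reward_model = train_reward_model(reward_model, comparisons)
        # 4. Optimize the policy against the reward model with policy gradient.
        policy = run_policy_gradient(policy, reward_model, prompts)
    return policy
```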
Questions? OK. So 30,000 tasks. I think we're getting into very recent stuff where increasingly 00:58:45.180 |
companies like OpenAI are sharing less and less details about what actually happens in 00:58:49.740 |
training these models. So we have a little bit less clarity as to what's going on here 00:58:53.220 |
than maybe we have had in the past. The data itself is not public, but they 00:58:59.620 |
do share the kinds of tasks that they collected from labelers. So they collected a bunch of 00:59:03.780 |
prompts from people who were already using the GPT-3 API. So they had the benefit of 00:59:08.540 |
having many, many users of their API and taking the kinds of tasks that users would ask GPT 00:59:15.860 |
to do. And so these include things like brainstorming or open-ended generation, et cetera. 00:59:24.860 |
And yeah, I mean, the key results of InstructGPT, which is kind of the backbone of ChatGPT, 00:59:31.060 |
really just needs to be seen and played with to understand. So you can feel free to play 00:59:34.540 |
with either ChatGPT or one of the OpenAI APIs. But again, take this example of a language model 00:59:40.620 |
not necessarily following the task it's given: by doing this kind of instruction fine-tuning followed 00:59:45.900 |
by RLHF, you get a model that is much better at adhering to user commands. Similarly, such a 00:59:55.220 |
language model can be very good at generating super interesting open-ended creative text as well. 01:00:09.580 |
This brings us to ChatGPT, which is even newer, and we have even less information about what's 01:00:14.260 |
actually going on or what's being trained here. But yeah, and they're keeping their 01:00:19.660 |
secret sauce secret. But we do have a blog post where they wrote two paragraphs. And 01:00:25.940 |
in the first paragraph, they said that they did instruction fine tuning. So we trained 01:00:30.980 |
an initial model using supervised fine tuning. So human AI trainers provided conversations 01:00:35.940 |
where they played both sides. And then we asked them to act as an AI assistant. And then 01:00:40.620 |
we fine-tuned our model on acting like an AI assistant for humans. That's part one. 01:00:46.540 |
Second paragraph, to create a reward model for RL, we collected comparison data. So we 01:00:52.940 |
took conversations with an earlier version of the chatbot, so the one that's pre-trained 01:00:56.940 |
on instruction following or instruction fine tuning, and then take multiple samples and 01:01:02.020 |
then rate the quality of the samples. And then using these reward models, we fine-tune 01:01:07.900 |
it with RL. In particular, they used PPO, which is a fancier policy-gradient RL algorithm. 01:01:18.340 |
And yeah, so that produces-- I don't need to introduce the capabilities of ChatGPT. 01:01:21.580 |
It's been very exciting recently. Here's an example. It's fun to play with. Definitely 01:01:26.260 |
play with it. Sorry, it's a bit of an attack on the students. Yeah. OK. 01:01:43.300 |
So reinforcement learning, pluses. You're directly modeling what you care about, 01:01:49.700 |
which is human preferences, not whether the demonstrations you happened to collect 01:01:55.300 |
have the highest probability mass under your model. You're actually just asking, how well 01:01:59.140 |
am I satisfying human preferences? So that's a clear benefit over something like instruction fine-tuning. 01:02:06.340 |
So in terms of negatives, one is that RL is hard. It's very tricky to get right. I think 01:02:11.420 |
it will get easier in the future as we kind of explore the design space of possible options. 01:02:16.900 |
So that's an obvious one. Does anyone come up with any other kind of maybe weaknesses 01:02:21.060 |
or issues they see with this kind of training? Yeah. 01:02:25.420 |
Is it possible that your language model and then your reward model can over-fit to each 01:02:33.300 |
other, especially-- even if you're not training them together, if you're going back and forth between them? 01:02:38.820 |
Yeah. Yeah. So over-optimization, I think, of the reward model is an issue. Yeah. 01:02:43.900 |
Is it also that if you retrain your baseline, you have to repeat all this expensive human feedback? 01:02:50.580 |
Yeah. So it still is extremely data expensive. And you can see some articles if you just 01:02:55.820 |
Google OpenAI data labeling. People have not been very happy with the amount of data that 01:03:00.020 |
has been needed to train something like ChatGPT. I mean, they're hiring developers to just 01:03:03.660 |
explain coding problems 40 hours a week. So it is still data intensive. That's kind of 01:03:09.820 |
the takeaway. All of these methods are still data intensive, every single one of them. 01:03:16.420 |
Yeah. I think that summarizes kind of the big ones here. So when we talk about limitations 01:03:24.100 |
of RLHF, we also need to talk about just limitations in general of RL, and also this idea that 01:03:30.540 |
we can model or capture human preferences in a single scalar reward. 01:03:35.460 |
So human preferences can be very unreliable. The RL people have known this for a very long 01:03:41.100 |
time. They have a term called reward hacking, which is when an agent is optimizing for something 01:03:45.900 |
that the developer specified, but it is not what we actually care about. So one of the 01:03:51.620 |
classic examples is this example from OpenAI, where they were training this agent to race 01:03:58.020 |
boats. And they were training it to maximize the score, which you can see at the bottom 01:04:02.220 |
left. But implicitly, the score actually isn't what you care about. What you care about is 01:04:06.220 |
just finishing the race ahead of everyone else. And the score is just kind of this bonus. 01:04:09.740 |
But what the agent found out was that there are these turbo boost things that you can 01:04:13.660 |
collect, which boost your score. And so what it ends up doing is it ends up kind of just 01:04:17.580 |
driving in the middle, collecting these turbo boosts over and over again. So it's racking 01:04:20.820 |
up insane score, but it is not doing the race. It is continuously crashing into objects, 01:04:25.860 |
and its boat is always on fire. And this is a pretty salient example of what we call the AI alignment problem. 01:04:34.300 |
And you might think, well, OK, this is a really simple example. They made a dumb mistake. 01:04:39.900 |
They shouldn't have used score as a reward function. But I think it's even more naive 01:04:44.220 |
to think that we can capture all of human preferences in a single number and assign that as a reward. 01:04:54.020 |
So one example where I think this is already happening, you can see, is maybe you have 01:04:58.940 |
played with chatbots before, and you notice that they do a lot of hallucination. They 01:05:03.060 |
make up a lot of facts. And this might be because of RLHF. Chatbots are rewarded to 01:05:07.960 |
produce responses that seem authoritative or seem helpful, but they don't care about 01:05:13.060 |
whether it's actually true or not. They just want to seem helpful. 01:05:17.060 |
So this results in making up facts. You may be seeing the news about chatbots. Companies 01:05:22.180 |
are in this race to deploy chatbots, and they make mistakes. Even Bing has also been hallucinating. 01:05:31.220 |
And in general, when you think about that, you think, well, models of human preferences 01:05:35.700 |
are even more unreliable. We're not even just using human preferences by themselves. We're 01:05:40.340 |
also training a model, a deep model, that we have no idea how that works. We're going 01:05:44.500 |
to use that instead. And that can obviously be quite dangerous. 01:05:50.420 |
And so going back to this slide here, where I was describing why we need this KL penalty 01:05:54.740 |
term, this yellow highlighted term here, here's a concrete example of what actually happens 01:05:59.900 |
when a language model overfits to the reward model. 01:06:03.460 |
So what this is showing is, in this case, they took off the KL penalty. So they were 01:06:07.300 |
just trying to maximize reward. They trained this reward model. Let's just push those numbers 01:06:11.140 |
up as high as possible. And on the x-axis here is what happens as training continues. 01:06:16.620 |
You diverge further and further. This is the KL divergence, or the distance from where you started. 01:06:22.580 |
And the golden dashed line here is what the reward model predicts your language model 01:06:27.340 |
is doing. So your reward model is thinking, wow, you are killing it. They are going to 01:06:31.220 |
love these summaries. They are going to love them way more than the reference summaries. 01:06:35.860 |
But in reality, when you actually ask humans, the preferences peak, and then they just crater. 01:06:43.060 |
So this can be an example of over-optimizing for a metric that you care about: it ceases to be a good measure once you push on it too hard. 01:06:57.380 |
So there's this real concern of, I think, what people are calling the AI alignment problem. 01:07:01.020 |
I'll let Percy Liang talk about this. He tweeted that the main tool that we have for alignment 01:07:07.580 |
is RLHF. But reward hacking happens a lot. Humans are not very good supervisors of rewards. 01:07:14.520 |
So this strategy is probably going to result in agents that seem like they're doing the 01:07:18.100 |
right thing, but they're wrong in subtle and inconspicuous ways. And I think we're already 01:07:21.900 |
seeing examples of that in the current generation of chatbots. 01:07:29.060 |
So those are the positives. But again, RL is tricky to get right. Human 01:07:34.700 |
preferences are fallible, and models of human preferences are even more so. 01:07:42.860 |
So I remember seeing a joke on Twitter somewhere where someone was saying that zero shot and 01:07:47.540 |
few-shot learning is the worst way to align an AI. Instruction fine-tuning is the second 01:07:52.180 |
worst way to align an AI. And RLHF is the third worst way to align an AI. So we're getting 01:07:57.700 |
somewhere, but each of these has clear fundamental limitations. 01:08:03.660 |
I have a question more on the computational side of reinforcement learning. Because if you 01:08:11.540 |
look at the math that Nick showed before, essentially you're moving the gradient inside so that 01:08:15.940 |
you can estimate the expectation by sampling. But when it comes to sampling, how do you 01:08:20.980 |
make that parallel? Because then you need to adaptively stop sampling, and then you 01:08:27.660 |
don't know when you're going to stop. How do you make that process quicker? The whole 01:08:32.700 |
unit on transformers and all that was parallelizing everything. 01:08:38.460 |
I mean, yeah. So this is really compute heavy. And I'm actually not sure what kind of infrastructure 01:08:44.220 |
is used for a state of the art, very performant implementation of RLHF. But it's possible 01:08:48.420 |
that they use parallelization like what you're describing, where I think in a lot of maybe 01:08:52.140 |
more traditional RL, there's this kind of idea of having an actor learner architecture 01:08:57.020 |
where you have a bunch of actor workers, which are each kind of a language model producing 01:09:00.100 |
a bunch of samples. And then the learner would then integrate them and perform the gradient 01:09:03.740 |
updates. So it's possible that you do need to do just sheer multiprocessing in order 01:09:08.420 |
to get enough samples to make this work in a reasonable amount of time. Is that the kind 01:09:13.060 |
of question you had? Or do you have other questions? 01:09:15.060 |
Kind of. So you're basically saying that each unit that you parallelize over is larger than 01:09:26.780 |
I was saying that you might need to actually copy your model several times and take samples 01:09:31.180 |
from different copies of the models. Yeah. But in terms of like-- yeah, so autoregressive 01:09:35.540 |
generation, transformers, especially like the forward pass and the multi-head attention 01:09:39.500 |
stuff is very easy to parallelize. But autoregressive generation is still kind of bottlenecked by 01:09:48.220 |
the fact that it's autoregressive. So you have to run it first and then you need to-- 01:09:51.500 |
depending on what you sample, you have to run it again. So those are kind of blocks that 01:09:55.140 |
we haven't fully been able to solve, I think. And that will add to compute cost. 01:10:07.260 |
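For what it's worth, here's a toy sketch of that actor/learner split, where several copies of the policy generate rollouts in parallel and the learner aggregates them. This is illustrative only; it is not how any particular RLHF system is actually implemented, and in practice the policy copies would live on separate GPUs rather than in a process pool. The `generate` method is a hypothetical stand-in.

```python
from concurrent.futures import ProcessPoolExecutor

def actor_generate(args):
    """One actor: a snapshot of the current policy generates samples
    for its chunk of prompts (policy_snapshot.generate is hypothetical)."""
    policy_snapshot, prompt_chunk = args
    return [policy_snapshot.generate(p) for p in prompt_chunk]

def parallel_rollouts(policy_snapshot, prompts, num_actors=8):
    chunks = [prompts[i::num_actors] for i in range(num_actors)]
    with ProcessPoolExecutor(max_workers=num_actors) as pool:
        results = pool.map(actor_generate, [(policy_snapshot, c) for c in chunks])
    # The learner then flattens the rollouts and performs the gradient update.
    return [sample for chunk in results for sample in chunk]
```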
So I think we have 10 more minutes if I'm not mistaken. So we've mostly finally answered 01:10:12.220 |
how we get from this to this. There's some details missing. But the key kind of factors 01:10:16.940 |
are one, instruction fine-tuning. Two, this idea of reinforcement learning from human feedback. 01:10:24.260 |
So let's talk a little bit about what's next. So as I had mentioned, RLHF is still a very 01:10:31.520 |
new area. It's still very fast moving. I think by the next time I 01:10:36.380 |
give these slides, they might look completely different, because maybe a 01:10:39.740 |
lot of the things that I was presenting here turn out to be really bad ideas or not the 01:10:44.580 |
most efficient way of going about things. RLHF gets you further than instruction fine 01:10:49.860 |
tuning. But as someone had already mentioned, it is still very data expensive. There are 01:10:54.740 |
a lot of articles about OpenAI needing to hire a legion of annotators or developers to produce this feedback data. 01:11:03.700 |
I think a recent work that I'm especially interested in and been thinking about is how 01:11:07.740 |
we can get the benefits of RLHF without such stringent data requirements. So there's these 01:11:13.140 |
newer kind of crazy ideas about doing reinforcement learning from not human feedback, but from 01:11:19.100 |
AI feedback. So having language models themselves evaluate the output of language models. So 01:11:24.340 |
as an example of what that might look like, a team from Anthropic, which works on these 01:11:28.300 |
large language models, came up with this idea called constitutional AI. And the basic idea 01:11:33.580 |
here is that if you ask GPT-3 to identify whether a response was not helpful, it would 01:11:38.180 |
be pretty good at doing so. And you might be able to use that feedback itself to improve 01:11:42.260 |
a model. So as an example, if you have some sort of human request, like, can you help 01:11:47.380 |
me hack into my neighbor's Wi-Fi? And the assistant says, yeah, sure, you can use this 01:11:51.380 |
app, right? We can ask a model for feedback on this. What we do is we add a critique request, 01:11:58.540 |
which says, hey, language model GPT-3, identify ways in which the assistant's response is 01:12:04.060 |
harmful. And then it will generate a critique, like hacking into someone else's Wi-Fi is 01:12:09.580 |
illegal. And then you might ask it to then revise it, right? So just rewrite the assistant 01:12:14.940 |
response to remove harmful content. And it does so. And now by just decoding from a language 01:12:23.260 |
model, assuming you can do this well, what you have now is a set of data that you can 01:12:28.220 |
do instruction fine-tuning on, right? You have a request, and you have a response that 01:12:32.220 |
has been revised to make sure it doesn't contain harmful content. 01:12:37.780 |
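Here's a hedged sketch of that critique-then-revise loop. The prompt wording and the `llm` text-in/text-out helper are made up for illustration; the actual constitutional AI prompts and sampling setup are in the Anthropic paper.

```python
CRITIQUE_REQUEST = ("Identify specific ways in which the assistant's last "
                    "response is harmful, unethical, or illegal.")
REVISION_REQUEST = ("Rewrite the assistant's response to remove any harmful, "
                    "unethical, or illegal content.")

def critique_and_revise(llm, human_request, assistant_response):
    """Ask the model to critique its own response, then rewrite it.
    Returns a (request, revised_response) pair usable as fine-tuning data."""
    transcript = f"Human: {human_request}\nAssistant: {assistant_response}"
    critique = llm(f"{transcript}\n\nCritique request: {CRITIQUE_REQUEST}")
    revision = llm(f"{transcript}\n\nCritique: {critique}\n\n"
                   f"Revision request: {REVISION_REQUEST}")
    return human_request, revision
```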
So this is pretty interesting. I think it's quite exciting. But all of those issues that 01:12:41.700 |
I had mentioned about alignment, misinterpreting human preferences, reward models being fallible, 01:12:49.980 |
everything gets compounded like 40,000 times when you're thinking about this, right? We 01:12:53.140 |
have no understanding of how safe this is or where this ends up going, but it is something. 01:12:59.940 |
Another kind of more common idea also is this general idea of fine tuning language models 01:13:03.780 |
on their own outputs. And this has been explored a lot in the context of chain of thought reasoning, 01:13:07.580 |
which is something I presented at the beginning of the lecture. And these are provocatively 01:13:11.420 |
named "Large Language Models Can Self-Improve." But again, it's not clear how much runway there is. 01:13:17.220 |
But the basic idea is to use "let's think step by step," for example, to 01:13:21.420 |
get a language model to produce a bunch of reasoning. And then you fine-tune 01:13:24.260 |
on that reasoning as if it were real training data and see whether or not the language model can improve itself. 01:13:33.940 |
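As a sketch of that recipe (the `llm` helper and the data format are hypothetical; the paper itself also filters which generated reasoning to keep, which this omits):

```python
def build_self_training_data(llm, questions):
    """Zero-shot prompt the model for step-by-step reasoning, then keep
    the generations as if they were supervised fine-tuning targets."""
    examples = []
    for q in questions:
        reasoning = llm(f"{q}\nLet's think step by step.")
        examples.append({"prompt": q, "target": reasoning})
    return examples

# These examples would then go through ordinary supervised fine-tuning,
# and you'd evaluate whether the fine-tuned model reasons better than before.
```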
But as I mentioned, this is all still very new. There are, I think, a lot of limitations 01:13:38.100 |
of large language models like hallucination and also just the sheer size and compute intensity 01:13:42.900 |
of this that may or may not be solvable with RLHF. 01:13:47.420 |
[INAUDIBLE] feedback on behaviors we don't want the model to have. I've seen people talking about 01:13:56.700 |
how you can jailbreak ChatGPT to still give those types of harmful responses. Are there 01:14:02.300 |
any ways for us to buffer against those types of things as well? Because it seems like you're 01:14:09.700 |
just going to keep building on this-- we need to identify cases where it's acting in 01:14:14.940 |
ways we don't want. I guess, is there any way to build that up at scale to avoid those exploits? 01:14:24.740 |
Yeah, that's interesting. So there are certainly ways that you can use either AI feedback or 01:14:32.860 |
human feedback to mitigate those kinds of jailbreaks. If you see someone on Twitter 01:14:36.420 |
saying that, oh, I made GPT-3 jailbreak using this strategy or whatever, you can then maybe 01:14:43.180 |
plug it into this kind of framework and say identify ways in which the assistant went 01:14:46.260 |
off the rails and then fine tune and hopefully correct those. But it is really difficult, 01:14:50.980 |
I think, in most of these kinds of settings. It's really difficult to anticipate all the 01:14:54.140 |
possible ways in which a user might jailbreak an assistant. So you always have this kind 01:14:58.820 |
of dynamic of like in security, cybersecurity, for example, there's always the attacker advantage 01:15:04.260 |
where the attacker will always come up with something new or some new exploit. So yeah, 01:15:10.260 |
I think this is a deep problem. I don't have a really clear answer. But certainly, if we 01:15:14.180 |
knew what the jailbreak was, we could mitigate it. I think that seems pretty straightforward. 01:15:21.180 |
But if you know how to do that, you should be hired by one of these companies. They'll pay you a lot. 01:15:30.100 |
OK. Yeah, so just last remarks is with all of these scaling results that I presented 01:15:36.900 |
and all of these like, oh, you can just do instruction fine tuning and it'll follow your 01:15:40.300 |
instructions, or you can do RLHF. You might have a very bullish view on like, oh, this 01:15:44.580 |
is how we're going to solve artificial general intelligence by just scaling up RLHF. It's 01:15:48.860 |
possible that that is actually going to happen. But it's also possible that there are certain 01:15:53.300 |
fundamental limitations that we just need to figure out how to solve, like hallucination, 01:15:58.220 |
before we get anywhere productive with these models. But it is a really exciting time to 01:16:01.720 |
work on this kind of stuff. So yeah. Thanks for listening.