
RL for Autonomous Coding — Aakanksha Chowdhery, Reflection.ai


Chapters

0:00 Introduction to LLMs and Scaling Laws
1:41 Emergent Behavior in LLMs
4:00 Reinforcement Learning from Human Feedback (RLHF)
6:11 Inference-Time Scaling and Verification
10:33 Challenges with Inference-Time Scaling
11:16 The Next Frontier: Reinforcement Learning for Correct Generation
13:20 Challenges in Scaling RL
14:58 Autonomous Coding as a Prime Domain for RL
15:53 Reflection.ai's Mission


00:00:00.000 | Hi, everyone. I'm Aakanksha Chowdhery. I was at Google for more than six years, and I led the research
00:00:21.360 | for PaLM, and I was a lead researcher on Gemini. These days, I'm working on pushing the frontier
00:00:27.120 | for Autonomous Coding with Reinforcement Learning. So just to recap the arc of how we have progressed
00:00:35.840 | in large language models, and why autonomous coding and why now. So I think everyone here remembers,
00:00:44.640 | or those of you who don't: in 2020, there was this breakthrough paper that came out,
00:00:50.000 | which talked about scaling laws for large language models. And if you were to take a 30-second recap,
00:00:56.160 | the main thing it said was that there's a power-law relationship between the test loss of large
00:01:01.920 | language models and the compute, data, and parameters used to train them. So if you use more compute, more data, and put more parameters in your machine
00:01:08.400 | learning model, which is a transformer model, you will get more performant models. And it will not
00:01:15.680 | be performant just in the domain in which you are training the model. It will actually be performant,
00:01:20.160 | and it will generalize to many other domains. And the generalization was pretty much a feature in this
00:01:27.360 | particular case. So as the large language models got bigger, we saw continuous improvement across
00:01:34.320 | benchmarks to the point that they're starting to get saturated now. And the other interesting thing was
00:01:40.320 | that we saw emergent behavior where capabilities were emerging in large language models that were not
00:01:46.240 | present in smaller models. And this is a classic slide that I show for the work that we did in PaLM.
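
As a quick aside on the scaling-law recap above, the relationship has a simple general form (this is only a sketch of the shape of the law; the fitted constants come from the 2020 paper and are not quoted in this talk): the test loss falls off as a power law in parameter count N, dataset size D, and compute C,

    L(N) \approx (N_c / N)^{\alpha_N}, \quad L(D) \approx (D_c / D)^{\alpha_D}, \quad L(C) \approx (C_c / C)^{\alpha_C}

where N_c, D_c, C_c and the small positive exponents alpha are empirically fitted constants.
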
00:01:54.160 | So typically when you go about trying to solve math problems and you give the model some examples,
00:02:00.720 | on the left you have a math problem around tennis balls, and then you give a second problem,
00:02:06.160 | the model output looks wrong. But what PaLM and the subsequent set of papers showed was that
00:02:12.880 | if you ask the model to output its reasoning chains, which has become a very common concept now, but
00:02:19.440 | this was, remember, 2021, so four years ago, then
00:02:26.080 | the answer actually is correct. So basically by getting the model to output its chain of thought
00:02:32.400 | or reasoning chains, the model performance improves. And this capability particularly emerged in large
00:02:39.920 | language models. These are all the models. So LaMDA and PaLM were the state-of-the-art models about three
00:02:45.840 | years ago. And what I'm showing on the x-axis is the increasing number of parameters. PaLM was scaled all
00:02:52.160 | the way up to 540 billion parameters. No one actually publishes the number of parameters these days, so
00:02:57.840 | you have to live with the graphs from three years ago or the open source stuff that's coming out with
00:03:03.040 | DeepSeek and Qwen models. But what the y-axis is showing is that the solve rate on middle-school math
00:03:09.600 | word problems was increasing with the number of parameters in the models. And it was essentially
00:03:16.160 | increasing mainly when you are prompting the models and asking them to show chain of thought. And this
00:03:22.560 | led to all kinds of prompting techniques where you ask the model to think step by step. You even go and
00:03:27.280 | bribe the model and such, and you ask the model nicely or not. So this was all kinds of fun stuff. And I think
00:03:35.120 | the thing that really stood out from this generation of models a few years ago was that this
00:03:41.680 | capability was not just limited to math problems. It was basically generalizing across
00:03:48.240 | a whole bunch of domains anywhere from question answering in other languages to puzzle problems to
00:03:54.240 | multitask natural language understanding problems. And what this led to next was that
00:04:02.240 | now that these models could reason, we could get them to follow instructions. So the first set of
00:04:08.800 | applications that became possible with these large language models were chatbot applications. So everyone
00:04:14.480 | remembers that ChatGPT and now Gemini and various other chatbots have become extremely popular. All of us
00:04:21.840 | use them all the time. But what made them really possible was that when you give instructions to the model
00:04:27.360 | to go do something, it's actually able to do it. And the way it learns that is actually based on reinforcement
00:04:33.280 | learning. And the reinforcement learning data that we're giving to the model in this particular case
00:04:38.720 | is essentially data based on human feedback. So you're basically saying, okay, here is a set of questions.
00:04:46.960 | And if I were to give it to a human and there were two answers, which one would the human
00:04:53.200 | prefer? And if you have enough of this data and you train your model, you would actually end up with
00:04:58.960 | a better performance because you taught the model which set of responses to prefer. And this actually
00:05:05.120 | doesn't only work in chatbot applications, it also works in code. So on the bottom right, I'm showing that
00:05:11.360 | even if you were to do this for applications in code, you start to see some performance improvements.
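
To make the "which of two answers would a human prefer" idea concrete, here is a minimal sketch of how preference pairs are typically turned into a reward-model training signal with a Bradley-Terry style loss. This is illustrative only; it is not the pipeline of any particular model, and the example data and scores are made up:

import math

# One preference example: a prompt, the response the human preferred,
# and the response they rejected.
preference_pair = {
    "prompt": "Explain what a unit test is.",
    "chosen": "A unit test checks one small piece of code in isolation...",
    "rejected": "idk, tests are when you run the code I guess",
}

def reward_model_loss(score_chosen: float, score_rejected: float) -> float:
    """Bradley-Terry style loss: push the reward model to score the chosen
    response above the rejected one. The scores would come from a learned
    reward model; here they are just floats."""
    return -math.log(1.0 / (1.0 + math.exp(-(score_chosen - score_rejected))))

# If the reward model already ranks the chosen answer higher, the loss is
# small; if it ranks them the other way around, the loss is large.
print(reward_model_loss(2.0, -1.0))  # ~0.049
print(reward_model_loss(-1.0, 2.0))  # ~3.049
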
00:05:16.960 | Now, of course, the question is that last year there was a whole bunch of debate: are we
00:05:23.200 | hitting a wall in terms of performance of large language models, is pre-training no longer giving any gains?
00:05:28.560 | All of these questions were on the horizon. So what is next? And one of the key questions to remember
00:05:37.200 | in all of this is that when you go and pre-train the models, you end up spending a lot of money
00:05:42.160 | on training these models. It could be tens of millions of dollars. And when you do inference on
00:05:47.920 | the models, it's extremely cheap. These numbers are not endorsed by any of the companies I worked at,
00:05:54.800 | but these are public numbers from public sources. So going back to the main point that I want to make
00:06:01.760 | here is that training is extremely costly. So if you constantly try to scale up the model size,
00:06:07.920 | you end up in this regime of: if it's not giving performance gains, then can we get performance
00:06:15.120 | gains at inference time instead, because inference calls are so cheap? And a key idea that was extremely useful
00:06:24.480 | here was to get the models to generate multiple responses and then do majority voting. So
00:06:32.640 | in the example above, the exact prompt doesn't matter, but you've given a mathematical
00:06:40.160 | problem to a large language model and you're asking it to generate three answers independently. And then you
00:06:45.760 | basically do some voting on top of those answers. And if two answers match, then that's a majority vote. Or
00:06:53.360 | like if in this room I were to ask a question and all of you said, yes, then that is a majority vote.
00:06:58.720 | So similarly with large language models, if you can get the model to generate many, many samples and
00:07:04.240 | then get many of those answers to agree, this notion of majority voting
00:07:10.640 | or self-consistency had shown gains. So this kind of scaling of compute at inference time was clearly one
00:07:16.240 | avenue to go push on. Another avenue that emerged and showed substantial value was that you could
00:07:22.640 | sequentially revise your previous response. So as humans, oftentimes we write the first answer and then
00:07:29.680 | we go evaluate our answer and we're like, oh, there's some mistake here. It doesn't quite match. And then
00:07:34.560 | you go fix it. So basically, can we get LLMs to do the same kind of revision, looking at their previous
00:07:41.360 | responses? This was the second avenue: having longer chains of thought and getting
00:07:46.640 | the model to improve consistently at inference time based on that. And these kinds of techniques, where
00:07:52.960 | you can verify the correct answer, as in math or in programming where you have unit tests, showed
00:07:58.960 | very clear gains. So what I'm showing you here is an example from one of my colleagues'
00:08:04.400 | work at Stanford, which is a publicly published paper. And on the y-axis, we have
00:08:11.760 | pass at k, or coverage score. And on the x-axis, we have the number of samples. So as you
00:08:17.920 | take more and more samples along the x-axis, your accuracy is improving with an open-source DeepSeek model,
00:08:23.840 | just by taking more samples. So you're getting a very high score on SWE-bench Verified compared to even the
00:08:29.600 | state of the art back at the end of 2024. Of course, now all of these scores have been pushed up and we are
00:08:36.000 | roughly somewhere around 80% already. But what we want to take away here is the fact that these lines
00:08:45.200 | of work, they showed that inference time compute predictably gives us gains, especially in domains
00:08:51.600 | where we can verify. If we know how to verify the answers, then we actually know how to translate
00:08:58.800 | that into intelligence. And going back to my talk title, coding is one of those domains where we
00:09:04.720 | do have the capability to verify. And that gives us a tremendous advantage in terms of
00:09:11.200 | building superintelligence on top of autonomous coding.
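
To make the majority-voting (self-consistency) idea from a couple of slides back concrete, here is a minimal sketch. The generate_answer function is a stand-in for an actual LLM sampling call and is purely illustrative:

from collections import Counter

def generate_answer(prompt: str, seed: int) -> str:
    """Placeholder for an LLM sampling call with temperature > 0.
    In a real system this would be a model API call; here it is stubbed out."""
    fake_samples = ["42", "42", "41", "42", "39"]
    return fake_samples[seed % len(fake_samples)]

def self_consistency(prompt: str, num_samples: int = 5) -> str:
    """Sample several independent answers and return the majority vote,
    i.e. the answer that the largest number of samples agree on."""
    answers = [generate_answer(prompt, seed=i) for i in range(num_samples)]
    most_common_answer, _count = Counter(answers).most_common(1)[0]
    return most_common_answer

print(self_consistency("What is 6 * 7?"))  # -> "42"
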
00:09:16.160 | Of course, now you ask the question of what does automated verification mean here?
00:09:20.800 | So for inference time scaling to work, you need basically some way to say this output is correct.
00:09:28.720 | Now in math, this is a very simple example. If you were to give the input to solve this mathematical
00:09:34.560 | equation, and if you were to do the same calculation on a calculator, you can actually verify
00:09:40.320 | that that solution is correct. And similarly in math, you have formal
00:09:47.920 | proofs, so you can actually verify that things are correct. In coding, you have unit tests and
00:09:52.960 | compilers. You can actually generate the code and then use the compiler
00:09:58.800 | as a verifier. And in fact, in domains where you don't have this kind of verification,
00:10:05.040 | then there's a large gap. If you were to generate a lot of solutions and then do majority
00:10:10.240 | voting, you actually don't get as much gain. So what this roughly meant was that, okay,
00:10:15.600 | so inference time scaling would work in scenarios where I have automated verification,
00:10:20.640 | but that doesn't quite solve the problem for it to have real world impact. And the reason for that
00:10:26.880 | is shown in this graph: typically, if you do majority voting, and this is across
00:10:33.520 | multiple different models on GSM8K, which is middle-school math problems, and another math benchmark, if you were to
00:10:40.160 | sort the problems by correct fraction, you have to sample a lot. The correct generations could be very rare.
00:10:45.920 | So who has time to sample 10,000 times to get a correct solution? You would be sitting there waiting
00:10:51.920 | just to find the correct solution, unless you can actually figure out where the correct generation is.
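
The coverage (pass at k) curves mentioned a moment ago, and this point about correct generations being rare, can both be made concrete with the standard unbiased pass@k estimator. This is a sketch using the well-known formula, with made-up numbers:

from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimate of pass@k: the probability that at least one of k
    samples is correct, given that c of n generated samples were correct.
    Formula: 1 - C(n - c, k) / C(n, k)."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# If only 1 in 1,000 generations is correct, a handful of samples rarely
# contains a correct one, but coverage keeps climbing as you sample more.
for k in (1, 10, 100, 1000):
    print(k, round(pass_at_k(n=10_000, c=10, k=k), 4))
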
00:10:58.960 | So basically scaling inference time compute with just majority voting or longer reasoning chains is great
00:11:05.840 | in the sense that there are some correct solutions somewhere there, but it doesn't work well across
00:11:10.560 | the board. So what will get these models to learn to generate correctly during training? Well,
00:11:18.480 | in the chatbot application scenario, we saw that RL with human feedback did work. So can we apply the
00:11:25.040 | same principle here and get the model to generate correctly in settings where we can automatically
00:11:31.120 | verify the outputs? So our belief at Reflection is that the next frontier for scaling is reinforcement
00:11:38.000 | learning. And we already have proof points from some of the frontier labs as well.
00:11:42.800 | And as David Silver and Richard Sutton published recently, they agree, or rather they are the pioneers
00:11:50.800 | in reinforcement learning. They say that we are basically entering the era of experience: starting
00:11:57.280 | from AlphaGo and AlphaZero, where you had an era of simulation, the next, large-language-model
00:12:04.800 | era was where you scaled up with RL using human data. But the next era, from this year, is really the era of
00:12:12.240 | experience, which will lead us to superintelligence. So reinforcement learning will be a
00:12:16.960 | fundamental component in building superintelligent systems, especially in areas where we have
00:12:23.280 | automated verification. And one proof point for why this makes sense is that in math, over
00:12:32.000 | several papers (this is results from o1, but over several papers), we have already seen examples that if you
00:12:39.280 | give the model more test-time compute, shown on the x-axis on the right (test-time compute is the same as
00:12:45.440 | inference-time scaling), and you measure accuracy on the y-axis, accuracy goes up. But as you
00:12:52.320 | repeat this process with reinforcement learning, the training-time compute going up on the
00:12:58.160 | x-axis also improves the accuracy on the y-axis for a challenging benchmark in math called AIME. Most
00:13:04.480 | of these benchmarks saturate within a year, as you probably have learned by now. So this benchmark
00:13:09.760 | is already saturated. So now that I've hopefully convinced you that reinforcement learning, and
00:13:18.000 | scaling reinforcement learning, is the next frontier, you'd be like, okay, so why isn't everyone
00:13:24.960 | doing it? What's so challenging about it? So, having built large language models before,
00:13:29.920 | I can say a big part of building these systems ends up being that the machine learning plus systems stack
00:13:36.160 | for these systems is itself very challenging. So here is an example of why scaling up
00:13:44.000 | reinforcement learning is challenging. So if you are trying to do reinforcement learning with PPO,
00:13:49.440 | which is one of the algorithms used for RL with human feedback (later approaches moved to DPO),
00:13:56.480 | you have to keep four copies of different models. So if you imagine a really large model
00:14:03.120 | and then you have to keep four copies, you have to arrange them somewhere on GPUs in your large
00:14:08.080 | cluster. You can have some fun figuring out the exact layout, and it's a fun and
00:14:14.800 | interesting problem, but it's a hard problem in the sense that getting maximum utilization out of
00:14:20.080 | these systems and arranging them in the right way, just building that system, is extremely hard.
00:14:25.440 | And DeepSeek actually showed with DeepSeekMath that GRPO gets rid of the value model
00:14:31.200 | and only keeps three copies of the model, but that's still a very challenging problem.
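
As a rough sketch of why GRPO is lighter than PPO here (illustrative only, not DeepSeek's actual implementation): PPO-style RLHF typically keeps a policy, a frozen reference policy, a reward model, and a value model resident, while GRPO drops the value model and instead normalizes each sampled response's reward against the other responses in its group:

def grpo_advantages(group_rewards: list[float]) -> list[float]:
    """Group-relative advantages: (reward - group mean) / group std.
    This replaces the learned value model (critic) that PPO would use."""
    n = len(group_rewards)
    mean = sum(group_rewards) / n
    variance = sum((r - mean) ** 2 for r in group_rewards) / n
    std = variance ** 0.5 or 1.0  # guard against identical rewards
    return [(r - mean) / std for r in group_rewards]

# Example: four sampled responses to one prompt, scored by a verifier.
print(grpo_advantages([1.0, 0.0, 0.0, 1.0]))  # -> [1.0, -1.0, -1.0, 1.0]
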
00:14:36.080 | So scaling up RL is even more challenging than scaling up LLMs, because you have multiple
00:14:44.800 | copies of the model and you have both a training loop and an inference loop. And then on the machine learning
00:14:50.080 | side, on the reinforcement learning side, you also suffer a lot from reward hacking if
00:14:55.280 | the model that is deciding what the correct answer is, is a neural reward model. But
00:15:00.720 | as we discussed before, in autonomous coding applications you do have the ability to verify
00:15:06.400 | your output, which roughly means that you can decide whether this is the correct answer or not.
00:15:12.560 | That's how SWE-bench Verified scores work today. You have execution feedback,
00:15:18.320 | you have unit tests. All of these possibilities, and of course this is a growing list, all of these
00:15:24.800 | possibilities may mean that you can design better reward functions. Okay. So this means that autonomous
00:15:31.680 | coding is a great domain for scaling up RL. Then the question becomes, how does this have real world impact?
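
Since the point here is that execution feedback and unit tests let you design better, verifiable reward functions than a neural reward model, here is a minimal sketch of what an execution-based reward for generated code might look like. The pytest-style test command, the paths, and the scoring thresholds are all illustrative assumptions, not anything described in the talk:

import subprocess

def coding_reward(repo_dir: str, test_command: str = "pytest -q") -> float:
    """Verifiable reward from execution feedback: run the project's unit
    tests against the model-generated patch and score the outcome. Unlike
    a neural reward model, this is hard to reward-hack as long as the
    tests themselves are trustworthy."""
    result = subprocess.run(
        test_command.split(),
        cwd=repo_dir,
        capture_output=True,
        text=True,
        timeout=600,
    )
    if result.returncode == 0:
        return 1.0   # all tests pass
    if "error" in result.stderr.lower():
        return -0.1  # the code did not even run (e.g. import or syntax error)
    return 0.0       # the code ran, but some tests failed

# Hypothetical usage: score = coding_reward("/tmp/generated_repo")
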
00:15:37.680 | So in software engineering applications, generation of code is only one part of the system. If you
00:15:44.240 | look at end-to-end workflows for software engineering, there are many more parts to that system. How do you
00:15:49.440 | scale up your system to generalize across all of those domains? So that's the problem we are trying to
00:15:55.200 | solve at Reflection. Our mission is that we would like to build superintelligence, and we are
00:16:00.880 | starting with autonomous coding as the root-node problem for this mission. And we have
00:16:08.240 | a team of about 35 pioneers who have pioneered various legendary works in LLMs and
00:16:16.160 | reinforcement learning. So if you're excited about this mission, you can reach out to one of us,
00:16:22.880 | or to my email, my last name at reflection.ai, and we would love to work with you. And with that, I can take
00:16:29.440 | questions.
00:16:37.280 | All right. Same protocol as last time. If you have a question, please come up to one of these
00:16:41.840 | three microphones we have distributed throughout. We can probably take one or two questions. So if you
00:16:46.400 | want to ask something, feel free. I guess I'll do the first one while people are coming
00:16:51.440 | up. So I'm curious: it seems like the foundation model labs are trying to build one model and deploy it
00:16:59.040 | across everything. With the work you're doing right now, do you have an opinion on whether that's the right
00:17:03.600 | approach, or do you think there'll be more specialization on different languages or even
00:17:07.680 | on individual code bases? Or do you feel like the best approach is just to have one model
00:17:12.160 | that's trained across the greatest diversity of tasks possible? I think I will answer your
00:17:16.880 | question this way: building coding agents does require multiple capabilities, and to get
00:17:23.680 | there you will definitely need multiple LLM calls. Whether that's one model or multiple models,
00:17:28.800 | I think that's the secret sauce right now for most people. Fair enough. All right.
00:17:33.200 | Please. Hi. I'm wondering, in the slide with the chart of the era of simulation
00:17:40.880 | and the era of experience, they had put in AlphaGo, and in the previous one they also played
00:17:50.480 | StarCraft or something. Those all used MCTS, and maybe it's my unfamiliarity with them,
00:17:57.680 | but that's also data simulation. So we're using synthetic data for the era of experience as well. So
00:18:05.600 | why is that called simulation, and why is what we're doing right now not called simulation?
00:18:11.520 | What's the sort of overlap between simulation and experience? How do you think about that?
00:18:15.920 | I can ask Dave that question, you know, but going back to the point, I think the better way to
00:18:22.560 | answer that question is roughly what Greg covered in the last talk, where his comment was that in
00:18:28.800 | gaming you can envision what scenarios might happen next, and you're basically using that to build your
00:18:35.760 | reinforcement learning. So you're doing rollouts and you're basically building based on that.
00:18:40.320 | In the real world, in most scenarios, you have an imperfect rollout. So you don't have full knowledge
00:18:48.720 | of how the system might work. Simulation is possible in certain domains where you do build a
00:18:56.160 | world model, which is closer to robotics and all the work that's happening in the physical AI space,
00:19:02.400 | right? But in the real-world applications, which is what we're targeting, you will have
00:19:08.560 | imperfect things. So you have to actually experience the real world and you have to collect some data,
00:19:12.960 | and that data is not going to be in any way complete, nor will it completely search
00:19:18.320 | the exponential search space that could exist.