[00:00:00.000 --> 00:00:02.580] (upbeat music) [00:00:02.580 --> 00:00:06.360] - Hello, hello, this is swyx with a special edition [00:00:06.360 --> 00:00:09.420] of the Latent Space pod for NeurIPS 2023. [00:00:09.420 --> 00:00:12.760] Both Alessio and I were there covering what we could cover. [00:00:12.760 --> 00:00:15.520] It is an impossible conference, 15,000 people, [00:00:15.520 --> 00:00:19.040] 3,500 papers and tons and tons of sessions. [00:00:19.040 --> 00:00:20.800] So it's just impossible for two people to cover it, [00:00:20.800 --> 00:00:24.360] especially with limited time, but we did our best. [00:00:24.360 --> 00:00:26.900] A lot of you liked our OpenAI Dev Day coverage [00:00:26.900 --> 00:00:29.320] where we basically just jumped from paper to paper, [00:00:29.320 --> 00:00:31.980] person to person, founder to founder and got their takes. [00:00:31.980 --> 00:00:34.280] And this is effectively what we've tried to do here. [00:00:34.280 --> 00:00:36.420] It's still an experimental and new format for us, [00:00:36.420 --> 00:00:37.840] so we'd really love your feedback. [00:00:37.840 --> 00:00:39.620] We're actually doing a listener survey now. [00:00:39.620 --> 00:00:40.800] If you click into the show notes, [00:00:40.800 --> 00:00:42.240] we'd really love to hear your feedback [00:00:42.240 --> 00:00:45.140] and know what you wanna hear for 2024. [00:00:45.140 --> 00:00:47.200] So we recorded a lot of audio at NeurIPS [00:00:47.200 --> 00:00:49.260] and I figured the most logical way to cover this [00:00:49.260 --> 00:00:51.520] would be to start with the best papers. [00:00:51.520 --> 00:00:53.480] NeurIPS does hand out best paper awards. [00:00:53.480 --> 00:00:56.120] So we're gonna start with the hardest one to obtain, [00:00:56.120 --> 00:00:57.480] which is the Test of Time Award. [00:00:57.480 --> 00:00:59.320] The Test of Time Award is given to a paper [00:00:59.320 --> 00:01:01.400] that has stood the test of time, [00:01:01.400 --> 00:01:03.800] which by NeurIPS's definition is a paper [00:01:03.800 --> 00:01:06.040] that was published 10 years ago at NeurIPS. [00:01:06.040 --> 00:01:07.880] NeurIPS is in its 37th year. [00:01:07.880 --> 00:01:09.320] So this is honestly a flex [00:01:09.320 --> 00:01:11.440] that very, very few conferences can actually do. [00:01:11.440 --> 00:01:13.440] And it's really interesting to have the original authors [00:01:13.440 --> 00:01:15.920] of the paper come back and talk about what they've learned [00:01:15.920 --> 00:01:18.040] and how they look back at the past 10 years. [00:01:18.040 --> 00:01:19.960] So here's Jeff Dean and Greg Corrado. [00:01:19.960 --> 00:01:21.120] (audience applauding) [00:01:21.120 --> 00:01:22.160] - Thank you very much. [00:01:22.160 --> 00:01:23.840] I'm Jeff. [00:01:23.840 --> 00:01:24.920] - And I'm Greg. [00:01:24.920 --> 00:01:26.720] - And we're here to give a little talk [00:01:26.720 --> 00:01:28.360] and a retrospective on this work. [00:01:28.360 --> 00:01:31.080] So this work actually started out [00:01:31.080 --> 00:01:34.000] as an ICLR 2013 workshop paper [00:01:34.000 --> 00:01:37.040] with four of our co-authors working together. [00:01:37.040 --> 00:01:38.900] And in that work, we sort of explored [00:01:38.900 --> 00:01:43.080] a bunch of different sort of loss functions and techniques [00:01:43.080 --> 00:01:46.420] for optimizing word embedding representations. [00:01:46.420 --> 00:01:51.780] And really that was kind of the genesis of this work.
[00:01:51.780 --> 00:01:55.080] And that work was cited by quite a few people. [00:01:55.080 --> 00:01:57.680] And one of the things that we discovered in that work [00:01:57.680 --> 00:02:00.080] was that the skip-gram model, [00:02:00.080 --> 00:02:02.360] one of the few models that we evaluated [00:02:02.360 --> 00:02:03.600] in this workshop paper, [00:02:03.600 --> 00:02:05.640] really was showing better performance [00:02:05.640 --> 00:02:06.960] than some of the other ones that we worked on. [00:02:06.960 --> 00:02:08.460] So we decided to focus on that [00:02:08.460 --> 00:02:12.600] and really focus on the skip-gram model [00:02:12.600 --> 00:02:15.660] and then some interesting sort of optimization techniques [00:02:15.660 --> 00:02:19.040] to improve the optimization of the word embeddings [00:02:19.040 --> 00:02:22.180] and added the ability to do phrase embeddings as well. [00:02:22.180 --> 00:02:25.260] And along the way, Ilya joined us as a co-author, [00:02:25.260 --> 00:02:26.100] which was great. [00:02:26.100 --> 00:02:30.540] And this paper has been cited by a number of people, [00:02:30.540 --> 00:02:32.160] as Sergei mentioned. [00:02:32.160 --> 00:02:34.020] One thing we've discovered: [00:02:34.020 --> 00:02:36.420] including source code and trained representations [00:02:36.420 --> 00:02:37.960] really does boost your citation count. [00:02:37.960 --> 00:02:41.020] People have done this and used these [00:02:41.020 --> 00:02:42.980] downstream representations for all kinds of things [00:02:42.980 --> 00:02:45.500] and we're very gratified to see that in the community. [00:02:47.340 --> 00:02:50.360] And we also wanna highlight that three of our co-authors [00:02:50.360 --> 00:02:51.200] couldn't make it today. [00:02:51.200 --> 00:02:53.760] So Tomas, Ilya, and Kai couldn't be here, [00:02:53.760 --> 00:02:58.160] but on their behalf, we're delighted to be giving this talk. [00:02:58.160 --> 00:03:02.480] And with that, I'm gonna turn it over to Greg, I think. [00:03:02.480 --> 00:03:04.760] Oh no, we're older now, sorry. [00:03:04.760 --> 00:03:07.960] Sadly, we found more recent photos, [00:03:07.960 --> 00:03:12.440] and this is a Test of Time Award, and time has passed. [00:03:12.440 --> 00:03:14.040] - Yes, I think we survived. [00:03:14.040 --> 00:03:15.600] We survived the test, mostly. [00:03:16.480 --> 00:03:19.260] But so let's stand back and ask ourselves, [00:03:19.260 --> 00:03:22.540] what did we really learn from these papers? [00:03:22.540 --> 00:03:24.700] But before I get into that, [00:03:24.700 --> 00:03:27.460] I should probably stipulate [00:03:27.460 --> 00:03:29.700] that some of you out there rightfully say, [00:03:29.700 --> 00:03:31.660] well, we already believed these things [00:03:31.660 --> 00:03:33.140] before you published this work. [00:03:33.140 --> 00:03:35.440] And so for you, maybe this is really us [00:03:35.440 --> 00:03:37.280] reinforcing these points. [00:03:37.280 --> 00:03:39.860] Others of you might think that, [00:03:39.860 --> 00:03:42.260] well, the paper didn't really exactly prove this point, [00:03:42.260 --> 00:03:43.140] it just suggested it. [00:03:43.140 --> 00:03:45.340] So it foreshadowed it.
[00:03:45.340 --> 00:03:46.740] We don't have any quarrel [00:03:46.740 --> 00:03:50.380] with whether it was reinforcing or foreshadowing or learning, [00:03:50.380 --> 00:03:51.820] and so we'll just put that aside [00:03:51.820 --> 00:03:52.900] for the remainder of the talk [00:03:52.900 --> 00:03:56.140] and talk about what we think are at least the themes [00:03:56.140 --> 00:03:58.980] that were in this work that resonate today. [00:03:58.980 --> 00:04:03.340] So the first point is that semi-supervised objectives [00:04:03.340 --> 00:04:06.580] present an incredibly powerful opportunity, [00:04:06.580 --> 00:04:08.300] and we think that they're gonna be critical [00:04:08.300 --> 00:04:11.000] for natural language understanding going forward. [00:04:12.300 --> 00:04:15.500] We think that this paper shows that fast, parallel, [00:04:15.500 --> 00:04:19.740] and weakly synchronized computation [00:04:19.740 --> 00:04:23.400] really dominates over the sort of fruitless precision [00:04:23.400 --> 00:04:24.880] of tight synchronization. [00:04:24.880 --> 00:04:29.380] Focusing compute where it really helps [00:04:29.380 --> 00:04:31.460] and improves your learning of representations [00:04:31.460 --> 00:04:32.900] is what's most important. [00:04:32.900 --> 00:04:37.680] And tokenization can be used as a good trick [00:04:37.680 --> 00:04:39.760] to solve some nuanced problems. [00:04:39.760 --> 00:04:42.940] And then the last and I think most important point [00:04:42.940 --> 00:04:47.220] is that treating language as a sequence of dense vectors [00:04:47.220 --> 00:04:48.660] has proven to be really powerful, [00:04:48.660 --> 00:04:52.540] and honestly more powerful than I think we imagined [00:04:52.540 --> 00:04:54.580] when we started this work. [00:04:54.580 --> 00:04:58.100] So first on semi-supervised objectives, [00:04:58.100 --> 00:05:00.100] why is this so important? [00:05:00.100 --> 00:05:03.700] Of course, almost all machine learning systems today [00:05:03.700 --> 00:05:06.100] go through some period of supervised learning. [00:05:06.100 --> 00:05:07.780] We're always gonna use that, [00:05:07.780 --> 00:05:10.080] but there's too much to learn in the world [00:05:10.080 --> 00:05:12.680] to use supervised learning for everything. [00:05:12.680 --> 00:05:14.760] The promise of unsupervised learning, of course, [00:05:14.760 --> 00:05:16.800] is tantalizing but has been difficult [00:05:16.800 --> 00:05:18.140] to implement in practice. [00:05:18.140 --> 00:05:19.880] And so semi-supervised learning, [00:05:19.880 --> 00:05:24.280] the ability to construct a supervised-feeling dataset [00:05:24.280 --> 00:05:28.000] from an unlabeled corpus, [00:05:28.000 --> 00:05:29.160] is really what we think works. [00:05:29.160 --> 00:05:31.320] So what's the basic program here? [00:05:31.320 --> 00:05:35.380] You begin with a large corpus of sequence data, say text, [00:05:35.380 --> 00:05:38.080] choose a random window within that corpus, [00:05:38.080 --> 00:05:41.680] and then algorithmically construct inputs [00:05:41.680 --> 00:05:44.600] and target outputs on the fly. [00:05:44.600 --> 00:05:46.720] And I wanna underscore, I actually think doing it on the fly [00:05:46.720 --> 00:05:49.320] is part of what makes this method so powerful. [00:05:49.320 --> 00:05:51.940] And you have your choice about how you'll do it on the fly.
[00:05:51.940 --> 00:05:54.760] You might be taking a word in the corpus [00:05:54.760 --> 00:05:56.620] and trying to predict its neighbors, [00:05:56.620 --> 00:05:58.780] which is the so-called skip-gram model. [00:05:58.780 --> 00:06:01.220] You might be doing something like fill in the blank, [00:06:01.220 --> 00:06:03.940] or you might be trying to predict the end of the sequence, [00:06:03.940 --> 00:06:06.480] which is sort of the classic language modeling problem. [00:06:06.480 --> 00:06:08.740] All of these fit in this description. [00:06:08.740 --> 00:06:11.960] And if you repeat that a few billion times, [00:06:11.960 --> 00:06:14.160] it seems to work really well. [00:06:14.160 --> 00:06:17.080] But that's where we get into the hard part. [00:06:17.080 --> 00:06:19.600] - Yeah, so I think one of the things [00:06:19.600 --> 00:06:21.600] that we really explored in this work [00:06:21.600 --> 00:06:24.920] and sort of work we were doing concurrently with this [00:06:24.920 --> 00:06:27.200] is how effectively could we make [00:06:27.200 --> 00:06:29.880] sort of weakly synchronized, asynchronous updates [00:06:29.880 --> 00:06:31.920] to a large model work? [00:06:31.920 --> 00:06:34.720] And Tomas, our first author, [00:06:34.720 --> 00:06:37.700] had been exploring these word embedding ideas [00:06:37.700 --> 00:06:40.740] on a single machine version that he implemented in C, [00:06:40.740 --> 00:06:44.680] of both the skip-gram and the continuous bag of words [00:06:44.680 --> 00:06:45.680] objective functions. [00:06:45.680 --> 00:06:48.280] And he actually did a fair amount of work [00:06:48.280 --> 00:06:51.620] to scale this up to be a very high-performance implementation [00:06:51.620 --> 00:06:53.620] using all the cores on a single machine, [00:06:53.620 --> 00:06:55.960] so about 20 different cores at that time, [00:06:55.960 --> 00:06:57.300] with almost no synchronization. [00:06:57.300 --> 00:06:59.720] So you just kind of blindly update the embedding [00:06:59.720 --> 00:07:04.540] that was sort of a large 2D array in memory. [00:07:04.540 --> 00:07:06.800] And then he was able to have about 20 cores [00:07:06.800 --> 00:07:08.000] on these multi-core machines [00:07:08.000 --> 00:07:10.180] simultaneously updating this shared representation [00:07:10.180 --> 00:07:13.020] and get quite good embedding representations. [00:07:13.020 --> 00:07:17.440] Now, one of the things that we observed [00:07:17.440 --> 00:07:20.160] was every time we made the dimensionality [00:07:20.160 --> 00:07:21.920] of the word vectors larger, [00:07:21.920 --> 00:07:24.480] and every time we trained on more data, [00:07:24.480 --> 00:07:26.040] things got better, right? [00:07:26.040 --> 00:07:29.440] This is the lesson of a lot of the last 10 years [00:07:29.440 --> 00:07:30.780] of deep learning work, [00:07:30.780 --> 00:07:33.740] is scaling actually gives you much better results. [00:07:33.740 --> 00:07:38.500] Fortunately, a bunch of us were simultaneously working [00:07:38.500 --> 00:07:40.380] on a highly scalable system [00:07:40.380 --> 00:07:42.240] for distributed training of neural networks. [00:07:42.240 --> 00:07:44.540] So we decided to take the single machine implementation [00:07:44.540 --> 00:07:48.120] that Tomas had built for these word embedding questions [00:07:48.120 --> 00:07:52.280] and implement that in our distributed framework.
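[Editor's note: To make the "construct inputs and targets on the fly" idea concrete, here is a minimal Python sketch of skip-gram pair generation. The window size, sample sentence, and function name are illustrative, not the authors' C implementation.]

```python
import random

def skipgram_pairs(tokens, window=5):
    """Construct (center, context) training pairs on the fly, as in the
    skip-gram objective: each word is used to predict its nearby neighbors."""
    for i, center in enumerate(tokens):
        # word2vec-style trick: sample an effective window size up to the
        # maximum, which implicitly weights nearer neighbors more heavily.
        span = random.randint(1, window)
        for j in range(max(0, i - span), min(len(tokens), i + span + 1)):
            if j != i:
                yield center, tokens[j]

corpus_window = "the quick brown fox jumps over the lazy dog".split()
print(list(skipgram_pairs(corpus_window, window=2))[:5])
# e.g. [('the', 'quick'), ('the', 'brown'), ('quick', 'the'), ...]
```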
[00:07:52.280 --> 00:07:54.800] And so the work we were doing kind of [00:07:54.800 --> 00:07:56.680] just a bit before this work [00:07:56.680 --> 00:07:58.980] was this large-scale distributed deep networks work, [00:07:58.980 --> 00:08:00.920] where we were exploring distributed training [00:08:00.920 --> 00:08:05.920] of large-scale models, mostly for vision and for speech. [00:08:05.920 --> 00:08:08.560] And really the motivation was how can we scale training [00:08:08.560 --> 00:08:11.080] on these systems to thousands of machines? [00:08:11.080 --> 00:08:16.560] We actually titled this DistBelief internally, [00:08:16.560 --> 00:08:19.280] so named because it was a distributed system, [00:08:19.280 --> 00:08:21.020] but also 'cause a bunch of people were skeptical [00:08:21.020 --> 00:08:22.600] that it would work. [00:08:22.600 --> 00:08:24.940] And turns out it did work, which is nice. [00:08:24.940 --> 00:08:28.120] So the basic idea behind DistBelief [00:08:28.120 --> 00:08:30.600] is you have some set of parameters [00:08:30.600 --> 00:08:33.440] that are being represented on some set of machines, [00:08:33.440 --> 00:08:35.960] and then you have independent replicas of the model [00:08:35.960 --> 00:08:38.040] where you fetch the current state of the parameters, [00:08:38.040 --> 00:08:39.740] you do some computation on the model, [00:08:39.740 --> 00:08:42.360] and then you update the parameters [00:08:42.360 --> 00:08:45.440] by sending a gradient back to the parameter servers. [00:08:45.440 --> 00:08:47.480] And in large-scale setups, [00:08:47.480 --> 00:08:49.440] we were using tens to hundreds of machines [00:08:49.440 --> 00:08:51.640] to hold the distributed state of the parameters, [00:08:51.640 --> 00:08:53.720] and hundreds to thousands of machines [00:08:53.720 --> 00:08:57.080] to hold the sort of independent workers of the model. [00:08:57.080 --> 00:08:58.200] And so that really meant you had [00:08:58.200 --> 00:09:01.120] 1,000 to 10,000 simultaneous threads [00:09:01.120 --> 00:09:02.640] kind of updating the model [00:09:02.640 --> 00:09:04.840] for the word embedding kind of work. [00:09:04.840 --> 00:09:07.900] And we were using 300 to 1,000 dimensional embeddings [00:09:07.900 --> 00:09:11.560] for a lot of things, 100K to million-item vocabularies, [00:09:11.560 --> 00:09:13.880] and even beyond for a lot of internal uses. [00:09:13.880 --> 00:09:15.520] It turns out you can make vocabularies [00:09:15.520 --> 00:09:17.900] out of lots of things, not just words, [00:09:17.900 --> 00:09:20.720] but particular videos that people have watched, [00:09:20.720 --> 00:09:21.600] or all kinds of things, [00:09:21.600 --> 00:09:24.560] and use kind of similar approaches [00:09:24.560 --> 00:09:26.140] beyond just language modeling. [00:09:27.060 --> 00:09:29.300] And now back to Greg. [00:09:29.300 --> 00:09:33.020] - And so this kind of provocative disrespect [00:09:33.020 --> 00:09:35.420] for locking and synchronization [00:09:35.420 --> 00:09:37.860] was the biggest single enabler [00:09:37.860 --> 00:09:39.460] of being able to do this work. [00:09:39.460 --> 00:09:40.980] But there were other things that we did [00:09:40.980 --> 00:09:42.300] that tried to focus compute [00:09:42.300 --> 00:09:43.860] to where it really actually made a difference [00:09:43.860 --> 00:09:46.280] in terms of model representation and quality.
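[Editor's note: A minimal sketch of the "blindly update a shared 2D array with almost no synchronization" idea, in the spirit of the single-machine trainer and, at larger scale, the DistBelief parameter-server setup described above. The dimensions, worker count, and stand-in gradient are arbitrary assumptions; this is not the original system.]

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

vocab_size, dim, lr = 10_000, 100, 0.025
# Shared embedding table, updated lock-free by many workers at once.
embeddings = np.random.uniform(-0.5 / dim, 0.5 / dim, (vocab_size, dim))

def worker(num_steps, seed):
    rng = np.random.default_rng(seed)
    for _ in range(num_steps):
        row = rng.integers(vocab_size)
        # Stand-in for a real skip-gram gradient on this embedding row.
        fake_gradient = rng.normal(scale=0.01, size=dim)
        # No lock around this read-modify-write: occasional clobbered updates
        # are tolerated because they barely hurt the final representations.
        embeddings[row] -= lr * fake_gradient

with ThreadPoolExecutor(max_workers=20) as pool:
    for s in range(20):
        pool.submit(worker, 10_000, s)
```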
[00:09:46.280 --> 00:09:50.860] So for example, the meaning of tokens that are uncommon [00:09:50.860 --> 00:09:53.460] is actually often more informative than common ones, [00:09:53.460 --> 00:09:55.260] and common ones are super easy to learn [00:09:55.260 --> 00:09:57.460] 'cause you get a lot of chances at them. [00:09:57.460 --> 00:10:00.340] So we would probabilistically discard tokens [00:10:00.340 --> 00:10:01.500] in proportion to their frequency, [00:10:01.500 --> 00:10:03.660] ignoring common tokens more often. [00:10:03.660 --> 00:10:08.660] And you could apply that both to inputs and to targets. [00:10:08.660 --> 00:10:09.640] Another thing that we did: [00:10:09.640 --> 00:10:11.740] we found that favoring objectives and models [00:10:11.740 --> 00:10:16.700] that were informative for the ultimate task [00:10:16.700 --> 00:10:18.340] but faster to compute was better. [00:10:18.340 --> 00:10:21.340] And so in our paper, you can see we go through softmax [00:10:21.340 --> 00:10:22.580] and then an approximation to that, [00:10:22.580 --> 00:10:24.060] through hierarchical softmax, [00:10:24.060 --> 00:10:26.660] and then noise contrastive estimation, [00:10:26.660 --> 00:10:27.900] which is an even faster version, [00:10:27.900 --> 00:10:29.720] and then Ilya came up with negative sampling, [00:10:29.720 --> 00:10:31.900] which is an even faster, faster, faster version. [00:10:31.900 --> 00:10:33.660] We saw that quality went up [00:10:33.660 --> 00:10:36.620] every time that we were able to make it simpler and faster. [00:10:36.620 --> 00:10:41.180] We also found that you could use tools like tokenization [00:10:41.180 --> 00:10:43.820] to focus computation on the part that was interesting. [00:10:43.820 --> 00:10:45.140] One of the things that we used it for [00:10:45.140 --> 00:10:47.440] was to try to deal with phrase representation. [00:10:47.440 --> 00:10:49.860] So in English, compound concepts and nouns [00:10:49.860 --> 00:10:51.940] are often represented by multiple words. [00:10:51.940 --> 00:10:53.980] And so we just had a very simple heuristic [00:10:53.980 --> 00:10:58.260] that allowed us to build bigrams out of terms [00:10:58.260 --> 00:11:00.900] that were each individually not super frequent, [00:11:00.900 --> 00:11:03.560] but were co-occurring much more frequently together [00:11:03.560 --> 00:11:04.780] than you would expect. [00:11:04.780 --> 00:11:08.400] And many other authors have used tokenization schemes [00:11:08.400 --> 00:11:10.200] in these systems to great benefit, [00:11:10.200 --> 00:11:12.620] dealing with everything from contractions to declensions. [00:11:12.620 --> 00:11:15.220] And I just think it's important for us to not overlook [00:11:15.220 --> 00:11:18.700] that when we're processing text, we begin with tokenization. [00:11:18.700 --> 00:11:22.240] But then to the point of getting concepts [00:11:22.240 --> 00:11:24.460] to be n-dimensional vectors and how it is [00:11:24.460 --> 00:11:26.140] that this is so powerful. [00:11:26.140 --> 00:11:28.300] And I was actually trained as a neuroscientist. [00:11:28.300 --> 00:11:32.260] And so I saw this come up as ideas from a long time ago, [00:11:32.260 --> 00:11:36.060] from the '80s, about how maybe concepts could be represented [00:11:36.060 --> 00:11:40.540] in a dense vector space, and that operators [00:11:40.540 --> 00:11:43.340] in that vector space or geometric relationships [00:11:43.340 --> 00:11:46.460] in that vector space actually meant something.
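[Editor's note: Two of the tricks mentioned here have simple closed forms in the word2vec paper: the frequent-word subsampling probability and the bigram phrase score. A small sketch follows; the threshold constants and example counts are illustrative.]

```python
import math

def keep_probability(word_freq, t=1e-5):
    """Frequent-word subsampling: probability of keeping a token whose
    corpus frequency is word_freq. Common tokens are discarded more often,
    focusing compute on the rarer, more informative ones."""
    return min(1.0, math.sqrt(t / word_freq))

def phrase_score(count_ab, count_a, count_b, delta=5):
    """Bigram score: high when two individually infrequent words co-occur far
    more often than chance. Bigrams scoring above a chosen threshold get
    merged into a single phrase token (e.g. 'new_york')."""
    return (count_ab - delta) / (count_a * count_b)

print(keep_probability(1e-2))   # a very common token is kept only ~3% of the time
print(phrase_score(count_ab=800, count_a=5_000, count_b=4_000))
```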
[00:11:46.460 --> 00:11:47.920] But that was simply a conjecture. [00:11:47.920 --> 00:11:50.340] And then lo and behold, when we took these representations [00:11:50.340 --> 00:11:52.780] that we had learned in a semi-supervised fashion [00:11:52.780 --> 00:11:55.540] and investigated what was inside by, for example, [00:11:55.540 --> 00:11:58.260] flattening them into two dimensions using PCA, [00:11:58.260 --> 00:12:02.660] we found that syntactic relationships [00:12:02.660 --> 00:12:04.240] were represented geometrically, [00:12:04.240 --> 00:12:06.180] like these similar triangles [00:12:06.180 --> 00:12:07.780] representing the tenses of verbs, [00:12:07.780 --> 00:12:10.140] and that even arbitrary semantic relationships, [00:12:10.140 --> 00:12:12.340] like the relationship between countries and capitals [00:12:12.340 --> 00:12:16.260] or diseases and drugs, were also represented geometrically [00:12:16.260 --> 00:12:18.820] in this space as similar displacements. [00:12:18.820 --> 00:12:21.180] And that was really powerful. [00:12:21.180 --> 00:12:23.540] And then Tomas and Ilya were able to show [00:12:23.540 --> 00:12:24.920] that you could do these cute tricks, [00:12:24.920 --> 00:12:28.380] like solve analogies with simple vector arithmetic. [00:12:28.380 --> 00:12:30.180] By adding and subtracting vectors, [00:12:30.180 --> 00:12:34.240] you could see that sushi is to Japan [00:12:34.240 --> 00:12:36.420] as bratwurst is to Germany, [00:12:36.420 --> 00:12:39.500] well, at least according to the language model. [00:12:39.500 --> 00:12:42.720] And in fact, you could even just do simple addition [00:12:42.720 --> 00:12:45.760] to imagine combining concepts and discovering [00:12:45.760 --> 00:12:48.680] what concept is nearby in this vector space. [00:12:48.680 --> 00:12:53.500] So for example, putting together Russian and river, [00:12:53.500 --> 00:12:55.960] you get tokens like Volga River. [00:12:55.960 --> 00:13:00.460] - Okay, so summing it all up, [00:13:00.460 --> 00:13:01.660] what did we learn in these papers? [00:13:01.660 --> 00:13:02.900] Let's go back to the five points [00:13:02.900 --> 00:13:04.700] that Greg talked about in the beginning. [00:13:04.700 --> 00:13:06.500] So semi-supervised objectives applied [00:13:06.500 --> 00:13:09.660] to large text corpora are pretty important [00:13:09.660 --> 00:13:11.100] in natural language understanding. [00:13:11.100 --> 00:13:12.800] I would say definitely true today. [00:13:12.800 --> 00:13:16.700] Fast, parallel, weakly synchronized computation [00:13:16.700 --> 00:13:17.900] dominates in ML. [00:13:17.900 --> 00:13:22.000] Parallel, definitely. [00:13:22.000 --> 00:13:24.920] I would say larger scale specialized ML hardware [00:13:24.920 --> 00:13:28.360] has really enabled fully synchronized approaches to scale, [00:13:28.360 --> 00:13:31.440] even to the scale of models that we're training today. [00:13:31.440 --> 00:13:33.640] But I personally think that asynchronous approaches [00:13:33.640 --> 00:13:34.520] are gonna make a comeback, [00:13:34.520 --> 00:13:37.160] because I think we're sort of close to where [00:13:37.160 --> 00:13:38.840] we're gonna have to start reconsidering [00:13:38.840 --> 00:13:41.000] some of these asynchronous approaches [00:13:41.000 --> 00:13:43.200] to training very large models. [00:13:43.200 --> 00:13:44.960] Focus compute on the aspects of learning [00:13:44.960 --> 00:13:46.200] that need improvement.
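[Editor's note: A toy sketch of the analogy-by-vector-arithmetic trick described here: add and subtract word vectors, then look up the nearest neighbor by cosine similarity. The `emb` dictionary is hypothetical; with real word2vec vectors, the examples from the talk (sushi : Japan :: bratwurst : Germany, Russian + river ≈ Volga River) are the intended behavior.]

```python
import numpy as np

def nearest(query, embeddings, exclude=()):
    """Return the vocabulary item whose vector has the highest cosine
    similarity to the query vector, excluding the query's own words."""
    best, best_sim = None, -1.0
    q = query / np.linalg.norm(query)
    for word, vec in embeddings.items():
        if word in exclude:
            continue
        sim = float(vec @ q) / float(np.linalg.norm(vec))
        if sim > best_sim:
            best, best_sim = word, sim
    return best

# With real embeddings loaded into `emb` (a dict of word -> np.ndarray):
# nearest(emb["Japan"] - emb["sushi"] + emb["bratwurst"], emb,
#         exclude={"Japan", "sushi", "bratwurst"})   # ~ "Germany"
# nearest(emb["Russian"] + emb["river"], emb,
#         exclude={"Russian", "river"})              # ~ "Volga"
```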
[00:13:46.200 --> 00:13:48.720] Yeah, simpler, more parallel methods win out [00:13:48.720 --> 00:13:51.460] over more complex, less parallelizable models. [00:13:51.460 --> 00:13:53.560] You know, word2vec versus RNNs, [00:13:53.560 --> 00:13:55.240] Transformers versus LSTMs. [00:13:55.240 --> 00:13:57.080] I think this is a good lesson [00:13:57.080 --> 00:13:58.520] as we're thinking about future improvements [00:13:58.520 --> 00:13:59.560] to these things. [00:13:59.560 --> 00:14:01.440] Tokenization can be used to solve [00:14:01.440 --> 00:14:03.600] seemingly nuanced problems. [00:14:03.600 --> 00:14:05.680] Yeah, more powerful models on top [00:14:05.680 --> 00:14:07.620] have actually pushed tokenization [00:14:07.620 --> 00:14:10.320] in the opposite direction of our phrase-based vocabulary, [00:14:10.320 --> 00:14:15.320] where we now have kind of sub-word sort of tokenization, [00:14:15.640 --> 00:14:17.900] and that actually has seemed to work pretty well [00:14:17.900 --> 00:14:19.320] for some of these models [00:14:19.320 --> 00:14:21.880] that have more complex attention mechanisms on top. [00:14:21.880 --> 00:14:25.060] And treating language as a sequence of dense vectors [00:14:25.060 --> 00:14:26.720] is more powerful than expected. [00:14:26.720 --> 00:14:29.400] Definitely true today. [00:14:29.400 --> 00:14:31.640] So we're really honored to receive this award. [00:14:31.640 --> 00:14:34.080] Thanks to the committee that selected the work. [00:14:34.080 --> 00:14:35.080] We're really honored. [00:14:35.080 --> 00:14:38.120] And thanks to our co-authors who couldn't be here today, [00:14:38.120 --> 00:14:42.000] and there's their pictures, Tomas, Ilya, and Kai. [00:14:42.000 --> 00:14:45.280] Thank you for this delightful work and co-authoring. [00:14:45.280 --> 00:14:46.440] Were we still so young? [00:14:46.440 --> 00:14:48.240] Thanks, everybody. [00:14:48.240 --> 00:14:49.080] - Yeah, we picked the younger ones. [00:14:49.080 --> 00:14:52.440] (audience applauding) [00:14:52.440 --> 00:14:54.320] - By the way, there was some discussion at NeurIPS [00:14:54.320 --> 00:14:57.600] around what would be the 2024 Test of Time winner. [00:14:57.600 --> 00:15:01.120] There was some contention for GANs by Ian Goodfellow, [00:15:01.120 --> 00:15:04.080] but probably it's going to go to the sequence-to-sequence [00:15:04.080 --> 00:15:07.240] paper because that is the most influential [00:15:07.240 --> 00:15:08.440] for language models today. [00:15:08.440 --> 00:15:10.420] The only thing I know for sure is that I know [00:15:10.420 --> 00:15:13.920] what's gonna be the Test of Time award winner for 2027. [00:15:13.920 --> 00:15:15.760] Up next are the best paper awards from this year. [00:15:15.760 --> 00:15:16.760] There are two papers chosen, [00:15:16.760 --> 00:15:19.960] but probably the most relevant for AI engineers [00:15:19.960 --> 00:15:22.520] is the Mirage paper, or in other words, [00:15:22.520 --> 00:15:25.360] Are Emergent Abilities of Large Language Models a Mirage? [00:15:25.360 --> 00:15:27.120] And here is Schaeffer et al. [00:15:27.120 --> 00:15:31.320] - My name is Rylan Schaeffer, and this is our NeurIPS paper, [00:15:31.320 --> 00:15:35.580] Are Emergent Abilities of Large Language Models a Mirage? [00:15:35.580 --> 00:15:37.600] This is joint work with Brando Miranda [00:15:37.600 --> 00:15:39.400] and Professor Sanmi Koyejo. [00:15:40.480 --> 00:15:45.480] Our paper is a story about predictability and surprise.
[00:15:45.480 --> 00:15:48.840] Our story begins with predictability. [00:15:48.840 --> 00:15:51.220] As many of you know, several years ago, [00:15:51.220 --> 00:15:54.740] researchers observed a striking phenomenon: [00:15:54.740 --> 00:15:59.000] that as you fed large networks more and more data, [00:15:59.000 --> 00:16:02.620] the loss improved in a predictable manner. [00:16:02.620 --> 00:16:05.000] But it wasn't just the loss on test data. [00:16:05.000 --> 00:16:07.600] Other researchers observed that other quantities, [00:16:07.600 --> 00:16:12.320] scaling compute, scaling dataset size, scaling parameters, [00:16:12.320 --> 00:16:14.600] yielded predictable improvements [00:16:14.600 --> 00:16:17.140] in the performance of large networks. [00:16:17.140 --> 00:16:19.880] This was incredibly important because it told us [00:16:19.880 --> 00:16:22.260] that if you fed more into these models, [00:16:22.260 --> 00:16:24.420] you knew what you would get. [00:16:24.420 --> 00:16:26.060] That's extremely useful. [00:16:26.060 --> 00:16:30.140] But approximately three years ago, [00:16:30.140 --> 00:16:32.000] this story was turned on its head. [00:16:32.000 --> 00:16:35.280] There was a new story in town, [00:16:35.280 --> 00:16:39.120] a story of surprise in large language models. [00:16:39.120 --> 00:16:41.940] Specifically, perhaps the first instance of this [00:16:41.940 --> 00:16:45.660] was in the GPT-3 paper, where the authors observed [00:16:45.660 --> 00:16:48.680] that you might try having language models solve a task, [00:16:48.680 --> 00:16:51.320] like arithmetic, and you make them larger [00:16:51.320 --> 00:16:54.400] and larger and larger, and they're unable to do this task. [00:16:54.400 --> 00:16:58.960] But then, at some seemingly unforeseeable model scale, [00:16:58.960 --> 00:17:02.280] performance skyrockets, almost to ceiling, [00:17:02.280 --> 00:17:04.440] something that was unpredictable. [00:17:04.440 --> 00:17:06.320] But it wasn't just on arithmetic. [00:17:06.320 --> 00:17:10.080] It was also on many other tasks, IPA transliteration, [00:17:10.080 --> 00:17:12.840] word unscrambling, Persian question answering, [00:17:12.840 --> 00:17:15.120] all of these tasks across a variety [00:17:15.120 --> 00:17:17.440] of language model families. [00:17:17.440 --> 00:17:19.440] All of them seem to display [00:17:19.440 --> 00:17:22.780] these miraculous emergent abilities. [00:17:22.780 --> 00:17:24.960] What are emergent abilities? [00:17:24.960 --> 00:17:28.200] Emergent abilities were defined by the authors who coined the term [00:17:28.200 --> 00:17:32.300] as abilities that are not present in smaller scale models, [00:17:32.300 --> 00:17:35.720] but that are present in larger scale models. [00:17:35.720 --> 00:17:40.120] Critically, emergent abilities cannot be predicted [00:17:40.120 --> 00:17:43.760] by simply extrapolating the performance improvements [00:17:43.760 --> 00:17:45.400] on smaller scale models. [00:17:45.400 --> 00:17:48.620] These emergent abilities raised [00:17:48.620 --> 00:17:51.600] several interesting research questions. [00:17:51.600 --> 00:17:56.120] Questions like, what controls which abilities will emerge? [00:17:56.120 --> 00:17:59.280] What controls when abilities will emerge? [00:17:59.280 --> 00:18:03.440] How can we make desirable abilities emerge faster? [00:18:03.440 --> 00:18:07.600] And how can we ensure undesirable abilities never emerge?
[00:18:07.600 --> 00:18:10.160] These questions not only are fundamental scientific [00:18:10.160 --> 00:18:13.240] questions of interest to the machine learning community, [00:18:13.240 --> 00:18:15.140] but these are also fundamental questions [00:18:15.140 --> 00:18:18.140] for those interested in governmental policy or economics. [00:18:18.140 --> 00:18:21.200] What our paper asked is whether or not [00:18:21.200 --> 00:18:24.360] the story of emergent abilities is complete. [00:18:26.760 --> 00:18:30.160] Specifically, if you look at these emergent abilities, [00:18:30.160 --> 00:18:31.620] you might notice something. [00:18:31.620 --> 00:18:34.340] That if you hone in on the metrics, [00:18:34.340 --> 00:18:37.520] all of these metrics are quite harsh. [00:18:37.520 --> 00:18:39.480] They give no partial credit. [00:18:39.480 --> 00:18:41.120] Exact match, for instance. [00:18:41.120 --> 00:18:44.080] Either you exactly output the correct answer, [00:18:44.080 --> 00:18:45.360] or you do not. [00:18:45.360 --> 00:18:47.440] There is no in between. [00:18:47.440 --> 00:18:51.340] And so, it seemed, when we looked closer, [00:18:51.340 --> 00:18:54.280] that many emergent abilities appeared under metrics [00:18:54.280 --> 00:18:57.000] that non-linearly or discontinuously [00:18:57.000 --> 00:19:00.080] scored models' performance. [00:19:00.080 --> 00:19:03.660] For instance, we found that over 90% of emergent abilities [00:19:03.660 --> 00:19:06.160] on Google's large-scale BIG-bench [00:19:06.160 --> 00:19:08.440] were observed [00:19:08.440 --> 00:19:10.880] under just two metrics. [00:19:10.880 --> 00:19:13.840] One of those metrics, for those who haven't seen this, [00:19:13.840 --> 00:19:16.240] is called multiple-choice grade. [00:19:16.240 --> 00:19:19.480] It's like taking an A through D multiple-choice question. [00:19:19.480 --> 00:19:21.520] You get a score of one if you put [00:19:21.520 --> 00:19:23.960] the highest probability mass on that answer, [00:19:23.960 --> 00:19:25.760] and zero otherwise. [00:19:25.760 --> 00:19:27.960] The other metric was exact string match, [00:19:27.960 --> 00:19:31.440] where again, one point, if you get it exactly right, [00:19:31.440 --> 00:19:33.520] zero otherwise. [00:19:33.520 --> 00:19:36.680] This raised the specter that emergent abilities [00:19:36.680 --> 00:19:39.480] might not be due to fundamental changes [00:19:39.480 --> 00:19:41.520] in models with scale, [00:19:41.520 --> 00:19:45.000] but due to our evaluations of said models. [00:19:45.000 --> 00:19:49.280] So what exactly is this alternative that I'm positing? [00:19:49.280 --> 00:19:51.760] What is our alternative hypothesis? [00:19:51.760 --> 00:19:53.600] Let's walk through it. [00:19:53.600 --> 00:19:57.120] First of all, let's just suppose that the test loss falls [00:19:57.120 --> 00:19:59.640] as we increase the number of parameters in our models. [00:19:59.640 --> 00:20:01.760] So for example, motivated by power-law scaling, [00:20:01.760 --> 00:20:04.120] we might assume that the cross-entropy loss [00:20:04.120 --> 00:20:06.040] as a function of the number of parameters [00:20:06.040 --> 00:20:07.800] is some power law. [00:20:07.800 --> 00:20:09.640] What that means is if we visualize [00:20:09.640 --> 00:20:11.440] the number of model parameters [00:20:11.440 --> 00:20:15.000] against the cross-entropy loss in log-log space, [00:20:15.000 --> 00:20:17.440] we observe a very predictable linear trend.
[00:20:17.440 --> 00:20:22.680] In step two, we compute the probability mass [00:20:22.680 --> 00:20:24.360] that is placed on the correct token [00:20:24.360 --> 00:20:26.200] as a function of parameters. [00:20:26.200 --> 00:20:27.320] So how can we do this? [00:20:27.320 --> 00:20:31.400] Well, we know the definitional form of cross-entropy, [00:20:31.400 --> 00:20:33.200] and we know that we can substitute in [00:20:33.200 --> 00:20:35.000] our power-law scaling. [00:20:35.000 --> 00:20:37.240] So I can rearrange, and when I plot this, [00:20:37.240 --> 00:20:40.640] what I see is that as model parameters get larger, [00:20:40.640 --> 00:20:44.120] the probability mass that gets placed on the correct token [00:20:44.120 --> 00:20:46.480] asymptotes towards one. [00:20:46.480 --> 00:20:48.440] And everybody is comfortable with this. [00:20:48.440 --> 00:20:53.000] So how do we go from this to an emergent capability? [00:20:53.000 --> 00:20:55.640] The answer is we might choose a metric [00:20:55.640 --> 00:20:59.560] that non-linearly scores model performance. [00:20:59.560 --> 00:21:01.200] For example, suppose that we want to add [00:21:01.200 --> 00:21:03.280] two five-digit numbers, [00:21:03.280 --> 00:21:06.000] and we're gonna measure performance with accuracy. [00:21:06.000 --> 00:21:09.200] What scaling should we expect? [00:21:09.200 --> 00:21:12.420] Well, the answer is that unless you get every token correct, [00:21:12.420 --> 00:21:14.320] you get zero points. [00:21:14.320 --> 00:21:16.060] Ergo, to score one point, [00:21:16.060 --> 00:21:18.680] it's going to be the per-token probability [00:21:18.680 --> 00:21:21.640] raised approximately to the power of however many tokens [00:21:21.640 --> 00:21:23.440] you need to get correct. [00:21:23.440 --> 00:21:26.040] So what happens is this graph on the right [00:21:26.040 --> 00:21:28.920] that we like and know gets transformed [00:21:28.920 --> 00:21:31.060] into something that becomes much less predictable [00:21:31.060 --> 00:21:32.960] with model scaling. [00:21:32.960 --> 00:21:37.300] And indeed, this toy model qualitatively reproduces [00:21:37.300 --> 00:21:39.940] what's been observed empirically at large scale. [00:21:39.940 --> 00:21:43.420] But could we have done something differently? [00:21:43.420 --> 00:21:46.560] Yes, suppose we had done the evaluation differently. [00:21:46.560 --> 00:21:48.120] Suppose that we had chosen a different metric, [00:21:48.120 --> 00:21:51.520] one that linearly scores model performance. [00:21:51.520 --> 00:21:53.880] So for example, I might instead count [00:21:53.880 --> 00:21:56.480] merely the number of mistakes that the language model makes. [00:21:56.480 --> 00:21:59.640] For those in NLP, you might call this an edit distance. [00:21:59.640 --> 00:22:02.000] And what that then means is that the edit distance [00:22:02.000 --> 00:22:05.880] scales approximately linearly with the output length. [00:22:05.880 --> 00:22:07.600] And so if we look at this, [00:22:07.600 --> 00:22:10.280] instead what we find is when we plot model parameters [00:22:10.280 --> 00:22:12.080] versus the number of incorrect tokens, [00:22:12.080 --> 00:22:14.300] we find a very nice predictable trend [00:22:14.300 --> 00:22:17.840] that asymptotes towards zero as you make models bigger. [00:22:17.840 --> 00:22:19.900] So nothing has fundamentally changed. [00:22:19.900 --> 00:22:22.660] From one viewpoint, we saw a seemingly emergent ability.
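[Editor's note: A minimal numerical sketch of the toy model described here, assuming an arbitrary power law for per-token cross-entropy. The constants are illustrative, not fit to any real model; the point is the contrast between a harsh exact-match metric and a linear token-edit-distance metric on the same underlying quantity.]

```python
import numpy as np

N = np.logspace(7, 11, 50)            # model sizes (parameters), toy values
loss = 20.0 * N ** -0.25              # assumed power-law per-token cross-entropy
p_correct = np.exp(-loss)             # probability mass on the correct token

L = 10                                # target length, e.g. tokens to get right
exact_match = p_correct ** L          # harsh metric: every token must be right
edit_distance = L * (1 - p_correct)   # linear metric: expected number of mistakes

# exact_match hugs zero and then shoots up at some scale (looks "emergent"),
# while edit_distance declines smoothly and predictably with N.
```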
[00:22:22.660 --> 00:22:24.860] From a different viewpoint, we removed it. [00:22:24.860 --> 00:22:28.720] Of course, it's not just about linear [00:22:28.720 --> 00:22:29.940] and non-linear metrics. [00:22:29.940 --> 00:22:32.000] It can also be discontinuous metrics. [00:22:32.000 --> 00:22:35.940] So for example, let's consider that multiple choice metric. [00:22:35.940 --> 00:22:37.740] So multiple choice again is you get one [00:22:37.740 --> 00:22:39.420] if you place the highest probability mass [00:22:39.420 --> 00:22:41.060] on the correct option. [00:22:41.060 --> 00:22:44.760] And what that scaling looks like is you're at chance [00:22:44.760 --> 00:22:47.180] up until some unforeseeable critical threshold, [00:22:47.180 --> 00:22:49.340] at which point you jump to ceiling. [00:22:49.340 --> 00:22:51.020] And this again qualitatively matches [00:22:51.020 --> 00:22:53.120] what's been observed empirically at scale. [00:22:53.120 --> 00:22:57.060] So if we had done the evaluation differently, [00:22:57.060 --> 00:22:59.740] we could have chosen a continuous metric like the Brier score, [00:22:59.740 --> 00:23:01.660] which is just the mean squared error here [00:23:01.660 --> 00:23:03.460] between one and the probability mass, [00:23:03.460 --> 00:23:06.160] and then we find a very nice quadratic. [00:23:06.160 --> 00:23:07.800] So to summarize this together, [00:23:07.800 --> 00:23:09.820] we started with power-law scaling. [00:23:09.820 --> 00:23:12.740] We computed what the probability mass [00:23:12.740 --> 00:23:14.220] on the correct token is. [00:23:14.220 --> 00:23:17.940] If we chose a non-linear metric, we see an emergent ability. [00:23:17.940 --> 00:23:20.500] But if we chose a linear metric, we did not. [00:23:20.500 --> 00:23:22.700] Similarly, if we chose a discontinuous metric, [00:23:22.700 --> 00:23:23.680] emergent ability. [00:23:23.680 --> 00:23:26.540] If we choose a continuous metric, we do not. [00:23:26.540 --> 00:23:28.500] And so this is our alternative hypothesis [00:23:28.500 --> 00:23:29.740] for emergent abilities. [00:23:29.740 --> 00:23:32.760] Now, of course, to summarize this, [00:23:32.760 --> 00:23:34.900] there's basically three factors at play here. [00:23:34.900 --> 00:23:37.300] One of them is the metrics that I focused on. [00:23:37.300 --> 00:23:39.180] Another one is statistical: [00:23:39.180 --> 00:23:41.780] needing sufficient resolution [00:23:41.780 --> 00:23:45.260] when measuring discrete outcomes in order to accurately estimate [00:23:45.260 --> 00:23:46.500] the performance of models. [00:23:46.500 --> 00:23:49.060] And then third and finally, the third confounding factor [00:23:49.060 --> 00:23:52.500] is evaluating too few small and medium-sized models. [00:23:52.500 --> 00:23:57.140] So up till now, this has been Rylan's hypothesis. [00:23:57.140 --> 00:23:59.580] Do we have any actual evidence? [00:23:59.580 --> 00:24:00.820] And the answer is, in our paper, [00:24:00.820 --> 00:24:03.640] we considered three different types of evidence. [00:24:03.640 --> 00:24:05.180] We made and tested predictions [00:24:05.180 --> 00:24:07.560] using the largest publicly available model family [00:24:07.560 --> 00:24:09.200] at the time, GPT-3. [00:24:09.200 --> 00:24:11.540] We did a meta-analysis of published metrics [00:24:11.540 --> 00:24:13.300] and emergent abilities on Google's BIG-bench.
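[Editor's note: A small sketch of the discontinuous-versus-continuous contrast on a four-option multiple-choice task. The probability curve is a toy stand-in for how the mass on the correct option might grow with scale; the remaining mass is assumed to be split evenly across the other options.]

```python
import numpy as np

p_correct = np.linspace(0.20, 0.95, 40)   # toy curve rising smoothly past chance (0.25)
p_other = (1 - p_correct) / 3             # assumed even split over the three wrong options

# Multiple-choice grade: 1 only once the correct option carries the most mass,
# so the score sits at 0 and then jumps discontinuously to 1 past ~0.25.
mc_grade = (p_correct > p_other).astype(float)

# Brier-style score: squared error between 1 and the probability mass,
# which changes smoothly (quadratically) with the same underlying quantity.
brier = (1 - p_correct) ** 2
```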
[00:24:13.300 --> 00:24:16.180] And third, we induced emergent abilities [00:24:16.180 --> 00:24:19.460] in toy minuscule networks on vision tasks. [00:24:19.460 --> 00:24:21.540] The reason why we did this is because prior to our paper, [00:24:21.540 --> 00:24:22.940] we didn't know of any work that had found [00:24:22.940 --> 00:24:25.440] emergent abilities in vision tasks. [00:24:25.440 --> 00:24:29.100] So to induce them intentionally was quite novel. [00:24:29.100 --> 00:24:30.900] So let's walk through this. [00:24:30.900 --> 00:24:32.220] Let's first talk about the predictions [00:24:32.220 --> 00:24:34.740] that the mathematical model makes. [00:24:34.740 --> 00:24:36.140] The first is that if you change the metric, [00:24:36.140 --> 00:24:38.100] you should get more predictable scaling. [00:24:38.100 --> 00:24:41.220] So here, again, model parameters versus accuracy. [00:24:41.220 --> 00:24:43.100] As I increase the number of tokens [00:24:43.100 --> 00:24:45.140] that the model needs to output correctly, [00:24:45.140 --> 00:24:46.380] we should expect to observe [00:24:46.380 --> 00:24:49.340] an approximately geometric decrease in performance. [00:24:49.340 --> 00:24:51.740] So we start up here and then it falls. [00:24:51.740 --> 00:24:54.660] But if I change the metric to token edit distance, [00:24:54.660 --> 00:24:57.220] I should find this nice quasi-linear behavior. [00:24:57.220 --> 00:25:00.700] I'm now going to go test this in GPT-3. [00:25:00.700 --> 00:25:02.260] And that's precisely what we did. [00:25:02.260 --> 00:25:03.560] So here is accuracy. [00:25:03.560 --> 00:25:06.060] And again, here are the four models in the GPT-3 family. [00:25:06.060 --> 00:25:08.420] And again, we find that as the target length gets longer, [00:25:08.420 --> 00:25:12.540] you find accuracy decaying geometrically with the length of the target. [00:25:12.540 --> 00:25:15.660] And if I switch, using the exact same [00:25:15.660 --> 00:25:18.340] fixed data, to the other metric, [00:25:18.340 --> 00:25:20.900] I find very nice quasi-linear scaling. [00:25:20.900 --> 00:25:25.200] This is exactly what the toy mathematical model predicts. [00:25:25.200 --> 00:25:28.020] Moreover, there's a question about better statistics [00:25:28.020 --> 00:25:29.860] yielding more predictable scaling. [00:25:29.860 --> 00:25:32.120] What the toy model tells us is that [00:25:32.120 --> 00:25:35.080] when we said the tiny models are unable to do the task, [00:25:35.080 --> 00:25:36.800] that wasn't quite right. [00:25:36.800 --> 00:25:39.580] It was that their performance was so small, [00:25:39.580 --> 00:25:41.260] we didn't have sufficient resolution [00:25:41.260 --> 00:25:43.040] in order to estimate it. [00:25:43.040 --> 00:25:44.300] So what our toy model says is [00:25:44.300 --> 00:25:47.580] we really need to consider accuracy on a log scale. [00:25:47.580 --> 00:25:49.200] And to estimate these quantities, [00:25:49.200 --> 00:25:51.840] we need sufficient data to do so. [00:25:51.840 --> 00:25:53.700] So we scale up the amount of data. [00:25:53.700 --> 00:25:56.500] And again, we find that if we put accuracy on a log scale, [00:25:56.500 --> 00:25:58.820] we find a very, very nice separation [00:25:58.820 --> 00:26:00.120] with predictable behavior. [00:26:01.580 --> 00:26:04.200] Second, we conducted a meta-analysis [00:26:04.200 --> 00:26:06.280] of emergent abilities on Google's BIG-bench.
[00:26:06.280 --> 00:26:09.120] And what we found is that across many, many, many metrics, [00:26:09.120 --> 00:26:11.400] we could not find emergent abilities. [00:26:11.400 --> 00:26:14.420] But on a small subset, to be specific, [00:26:14.420 --> 00:26:16.800] four of these, we found emergent abilities. [00:26:16.800 --> 00:26:18.440] That's what this little pie chart shows. [00:26:18.440 --> 00:26:20.360] So long story short, it seems like the metric [00:26:20.360 --> 00:26:21.440] is playing a fundamental role [00:26:21.440 --> 00:26:24.240] in producing these emergent abilities. [00:26:24.240 --> 00:26:27.000] And lastly, what we did is we induced [00:26:27.000 --> 00:26:29.100] emergent abilities in networks. [00:26:29.100 --> 00:26:31.200] So what we did is we did the simplest possible thing. [00:26:31.200 --> 00:26:33.740] We took a shallow, non-linear autoencoder [00:26:33.740 --> 00:26:35.540] and trained it on CIFAR-100. [00:26:35.540 --> 00:26:36.440] Everybody has done this [00:26:36.440 --> 00:26:38.560] in their intro to machine learning class. [00:26:38.560 --> 00:26:40.000] And what we did is we plotted [00:26:40.000 --> 00:26:41.640] the squared reconstruction error [00:26:41.640 --> 00:26:44.400] as a function of the number of parameters. [00:26:44.400 --> 00:26:46.440] And this looks very smooth and predictable; [00:26:46.440 --> 00:26:47.740] everybody has seen this. [00:26:47.740 --> 00:26:50.660] But if we define a discontinuous metric, [00:26:50.660 --> 00:26:53.040] so here, the model scores one [00:26:53.040 --> 00:26:56.040] if the reconstruction error is below some threshold, [00:26:56.040 --> 00:26:59.440] then you find very, very unpredictable behavior. [00:26:59.440 --> 00:27:02.960] And so even in a shallow, non-linear autoencoder, [00:27:02.960 --> 00:27:04.680] we can, again, qualitatively produce [00:27:04.680 --> 00:27:06.780] what seems to be an emergent behavior. [00:27:06.780 --> 00:27:08.060] There's two takeaways. [00:27:08.060 --> 00:27:09.720] One is that for emergent abilities, [00:27:09.720 --> 00:27:11.880] it might be, in certain cases, [00:27:11.880 --> 00:27:13.940] the researchers' analyses [00:27:13.940 --> 00:27:15.080] that have produced these phenomena. [00:27:15.080 --> 00:27:17.200] That's why we call it a mirage. [00:27:17.200 --> 00:27:18.700] But there's a more general lesson [00:27:18.700 --> 00:27:19.960] that I want to leave you with. [00:27:19.960 --> 00:27:21.480] The more general lesson is that [00:27:21.480 --> 00:27:23.440] if you want to predict changes [00:27:23.440 --> 00:27:25.880] in model capabilities with increasing scale, [00:27:25.880 --> 00:27:27.960] you need to consider the interplay [00:27:27.960 --> 00:27:29.920] between known scaling properties, [00:27:29.920 --> 00:27:32.240] the amount and quality of evaluation data, [00:27:32.240 --> 00:27:35.140] and the specific metrics and evaluation processes [00:27:35.140 --> 00:27:36.760] that you have available. [00:27:36.760 --> 00:27:40.000] So with that, and with gratitude to all my collaborators, [00:27:40.000 --> 00:27:42.440] and everyone here for attending, thank you.
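[Editor's note: A toy sketch of the autoencoder experiment's logic, not the paper's actual training run: a smooth curve of reconstruction error versus parameters (the decay rate and threshold here are arbitrary) turns into an apparent jump once you score it with a below-threshold indicator.]

```python
import numpy as np

params = np.logspace(3, 7, 30)                 # toy autoencoder sizes
recon_error = 5.0 * params ** -0.3             # assumed smooth decay of squared error

threshold = 0.05                               # arbitrary cutoff for "success"
emergent_metric = (recon_error < threshold).astype(float)
# recon_error declines smoothly, but emergent_metric sits flat at 0
# and then jumps to 1, manufacturing an "emergent" ability.
```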
[00:27:42.440 --> 00:27:45.600] (audience applauding) [00:27:45.600 --> 00:27:48.520] - So for the purposes of this episode, [00:27:48.520 --> 00:27:50.520] we actually tried to do interviews [00:27:50.520 --> 00:27:52.480] at the poster sessions for each paper, [00:27:52.480 --> 00:27:54.960] but some we just didn't manage to find, [00:27:54.960 --> 00:27:58.160] or in the case of the emergent mirage paper, [00:27:58.160 --> 00:27:59.660] it was just way too popular. [00:27:59.660 --> 00:28:01.560] There were just so many people crowding around [00:28:01.560 --> 00:28:05.320] and listening to Rylan explain his paper again and again [00:28:05.320 --> 00:28:07.560] that we just couldn't get a proper question in. [00:28:07.560 --> 00:28:08.840] And I have to say, [00:28:08.840 --> 00:28:11.040] if I'm allowed to be a little bit critical, [00:28:11.040 --> 00:28:14.360] I'm a bit puzzled as to why this paper was the best paper. [00:28:14.360 --> 00:28:15.700] I mean, it's a good paper, [00:28:15.700 --> 00:28:18.960] but it doesn't really deny the existence of emergence. [00:28:18.960 --> 00:28:22.540] It just pointed out some methodological disagreements, [00:28:22.540 --> 00:28:25.480] which Jason Wei has also responded to. [00:28:25.480 --> 00:28:27.640] In other words, I don't really know [00:28:27.640 --> 00:28:31.000] if this paper affected literally anything in the field. [00:28:31.000 --> 00:28:33.200] So I don't know why it's best paper [00:28:33.200 --> 00:28:35.140] and not just a regular paper, [00:28:35.140 --> 00:28:36.760] but it's still a notable paper for sure. [00:28:36.760 --> 00:28:38.720] And it's very well done. [00:28:38.720 --> 00:28:41.280] Next, we have the runner up for best paper, [00:28:41.280 --> 00:28:43.480] which is direct preference optimization, [00:28:43.480 --> 00:28:46.060] which is a direct challenger to PPO. [00:28:46.060 --> 00:28:48.440] And you can hear directly from the authors. [00:28:48.440 --> 00:28:49.280] - Hi, everyone. [00:28:49.300 --> 00:28:52.660] My name's Eric, and I'm here with Rafael and Archit. [00:28:52.660 --> 00:28:53.620] And today we're gonna talk [00:28:53.620 --> 00:28:56.420] about direct preference optimization, [00:28:56.420 --> 00:29:01.420] which is this algorithm that simplifies RLHF, [00:29:01.420 --> 00:29:04.860] which is this algorithm framework [00:29:04.860 --> 00:29:08.740] that has sort of been taking the LLM world [00:29:08.740 --> 00:29:09.860] by storm recently. [00:29:09.860 --> 00:29:12.980] So to start, why are we even talking [00:29:12.980 --> 00:29:17.220] about reinforcement learning for language models now? [00:29:17.220 --> 00:29:19.320] It's not the first time people have been studying [00:29:19.320 --> 00:29:21.600] reinforcement learning in the context of language models, [00:29:21.600 --> 00:29:23.880] but the sort of simple answer to this question [00:29:23.880 --> 00:29:26.720] is that a few years ago, GPT-3 came onto the scene [00:29:26.720 --> 00:29:28.040] and it was sort of a big deal. [00:29:28.040 --> 00:29:30.800] And you probably, well, I'm an LLM person, [00:29:30.800 --> 00:29:32.840] but you probably heard from a lot of your researcher friends [00:29:32.840 --> 00:29:34.400] like, did you hear about this new model?
[00:29:34.400 --> 00:29:37.600] And then last year, ChatGPT came on the scene [00:29:37.600 --> 00:29:39.760] and it was more like, at least I was like getting texts [00:29:39.760 --> 00:29:41.040] from my grandmother saying like, [00:29:41.040 --> 00:29:42.880] hey, have you seen this new model, right? [00:29:42.880 --> 00:29:45.040] And these are just like two different levels [00:29:45.040 --> 00:29:47.180] of permeation in the public consciousness. [00:29:47.180 --> 00:29:51.140] And so what is the difference between these two models? [00:29:51.140 --> 00:29:53.660] And really the main sort of key ingredient [00:29:53.660 --> 00:29:57.220] is this reinforcement learning from human feedback [00:29:57.220 --> 00:29:59.660] framework, which lets us sort of align the behaviors [00:29:59.660 --> 00:30:03.020] of the models more towards what people kind of want [00:30:03.020 --> 00:30:04.460] or expect. [00:30:04.460 --> 00:30:06.360] Okay, so to give a little bit of an overview [00:30:06.360 --> 00:30:10.060] of what sort of the existing RLHF pipeline looked like [00:30:10.060 --> 00:30:11.900] kind of when we started working on this project, [00:30:11.900 --> 00:30:13.420] so there are basically two main steps. [00:30:13.420 --> 00:30:16.420] So the first step is we're going to start [00:30:16.420 --> 00:30:20.240] with some reasonably behaved kind of imitation, [00:30:20.240 --> 00:30:23.960] behavior-cloned policy, what we call pi theta SFT here, [00:30:23.960 --> 00:30:25.600] so supervised fine-tuned policy. [00:30:25.600 --> 00:30:29.400] We're gonna sample pairs of responses or trajectories [00:30:29.400 --> 00:30:33.340] from this policy conditioned on a prompt X, [00:30:33.340 --> 00:30:34.860] and that's how we're going to gather [00:30:34.860 --> 00:30:36.240] this data set of preferences. [00:30:36.240 --> 00:30:38.280] So we'll have an X and we'll have two Ys, [00:30:38.280 --> 00:30:39.920] and a human is gonna just label [00:30:39.920 --> 00:30:41.240] which Y they think is better. [00:30:41.620 --> 00:30:43.500] So they're just gonna give us this binary [00:30:43.500 --> 00:30:45.180] preference pair over responses, [00:30:45.180 --> 00:30:48.300] and we're going to use this data to fit a reward model. [00:30:48.300 --> 00:30:49.580] And then in the second step, we're just going [00:30:49.580 --> 00:30:52.140] to optimize a policy to maximize rewards. [00:30:52.140 --> 00:30:53.660] So that's just RL. [00:30:53.660 --> 00:30:55.980] Okay, so to look at this a little more closely [00:30:55.980 --> 00:30:58.700] in this first step, we get this feedback. [00:30:58.700 --> 00:31:01.220] It's these triples of a prompt and two responses. [00:31:01.220 --> 00:31:03.340] One is sort of the winner and one is the loser. [00:31:03.340 --> 00:31:05.380] And we're simply going to train a reward model [00:31:05.380 --> 00:31:07.220] with this binary classification loss [00:31:07.220 --> 00:31:08.060] on the preference data. [00:31:08.060 --> 00:31:10.120] So this is this Bradley-Terry model [00:31:10.120 --> 00:31:14.040] of discrete choice in humans from the '50s. [00:31:14.040 --> 00:31:16.080] But it has some nice properties, [00:31:16.080 --> 00:31:18.040] and it's relatively simple to understand, [00:31:18.040 --> 00:31:19.740] and we use this to fit this reward model.
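[Editor's note: A minimal PyTorch-style sketch of the Bradley-Terry reward-model loss described here. `reward_model` is a placeholder for any network that maps a prompt and completion to a scalar reward.]

```python
import torch
import torch.nn.functional as F

def reward_model_loss(reward_model, x, y_win, y_lose):
    """Bradley-Terry binary classification loss on preference triples:
    maximize the log-sigmoid of the reward margin between the human-preferred
    completion y_win and the dispreferred completion y_lose."""
    r_w = reward_model(x, y_win)   # scalar reward per example
    r_l = reward_model(x, y_lose)
    return -F.logsigmoid(r_w - r_l).mean()
```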
[00:31:19.740 --> 00:31:21.260] So we're just taking the difference in the rewards, [00:31:21.260 --> 00:31:25.400] and we have this sort of Boltzmann rational model here [00:31:25.400 --> 00:31:28.140] that we're fitting with maximum likelihood. [00:31:28.140 --> 00:31:31.040] Okay, and so now that we're done with this first step, [00:31:31.040 --> 00:31:33.120] what are we going to do with this reward model? [00:31:33.120 --> 00:31:34.680] Well, we're going to try to find a policy [00:31:34.680 --> 00:31:36.560] achieving high reward. [00:31:36.560 --> 00:31:38.300] And so, you know, ideally this reward model [00:31:38.300 --> 00:31:41.220] after we've done this supervised learning stage [00:31:41.220 --> 00:31:45.220] should represent goodness according to what humans want. [00:31:45.220 --> 00:31:46.620] And so we're just going to fit a policy [00:31:46.620 --> 00:31:48.100] that both achieves high reward, [00:31:48.100 --> 00:31:50.260] but also stays close to our original model, [00:31:50.260 --> 00:31:53.460] our reference model, or our supervised fine-tune model. [00:31:53.460 --> 00:31:57.260] And so that means we're going to try to find a policy here, [00:31:57.260 --> 00:32:00.900] pi theta, that generates samples that achieve high reward [00:32:00.900 --> 00:32:03.540] under our learned reward model, [00:32:03.540 --> 00:32:05.380] but also stays close to our original model, [00:32:05.380 --> 00:32:07.860] our reference model, because if you remember, [00:32:07.860 --> 00:32:10.540] we actually fit our reward model on samples [00:32:10.540 --> 00:32:12.260] that were annotated by humans, [00:32:12.260 --> 00:32:14.420] but these samples were generated by our reference model, [00:32:14.420 --> 00:32:15.880] our supervised fine-tune model, right? [00:32:15.880 --> 00:32:18.260] So we don't want our policy to drift too far away [00:32:18.260 --> 00:32:21.780] because we want to stay in the regime [00:32:21.780 --> 00:32:23.660] where our reward model's actually reliable. [00:32:23.660 --> 00:32:26.660] Okay, so now that we have this objective, [00:32:26.660 --> 00:32:29.340] we take some off-the-shelf RL algorithm, [00:32:29.340 --> 00:32:33.200] typically it's PPO, and we find a policy [00:32:33.200 --> 00:32:35.260] that optimizes these rewards. [00:32:35.260 --> 00:32:38.180] This is a very complicated procedure, [00:32:38.180 --> 00:32:40.660] so there's this nice figure in this recent paper [00:32:40.660 --> 00:32:43.220] showing sort of the full pipeline of just the PPO step, [00:32:43.220 --> 00:32:44.580] and there are a lot of moving pieces here. [00:32:44.580 --> 00:32:47.000] And so in light of sort of this complexity, [00:32:47.000 --> 00:32:49.340] we kind of set out to see if there's some way [00:32:49.340 --> 00:32:50.820] we can sort of use the structure of this problem [00:32:50.820 --> 00:32:51.780] to simplify things. [00:32:51.780 --> 00:32:58.940] - All right, so how the heck do we solve this optimization [00:32:58.940 --> 00:33:00.340] without reinforcement learning, [00:33:00.340 --> 00:33:02.860] or what we call direct preference optimization? [00:33:02.860 --> 00:33:06.480] Really the key here is that the optimization [00:33:06.480 --> 00:33:10.080] that was set up for RLHF has a closed-form optimal solution. [00:33:10.080 --> 00:33:11.720] Now this may look a bit intimidating, [00:33:11.720 --> 00:33:13.680] but it's really just the reference distribution [00:33:13.680 --> 00:33:15.880] re-weighted by the exponentiated reward.
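[Editor's note: Written out, the KL-constrained objective described here and its closed-form optimum, with beta as the KL penalty coefficient and Z(x) the partition function. Notation follows the talk; this is a compact restatement, not a new result.]

```latex
% KL-constrained RLHF objective:
\max_{\pi}\;\mathbb{E}_{x\sim\mathcal{D},\,y\sim\pi(\cdot\mid x)}\big[r(x,y)\big]
  \;-\;\beta\,\mathbb{D}_{\mathrm{KL}}\!\big[\pi(\cdot\mid x)\,\|\,\pi_{\mathrm{ref}}(\cdot\mid x)\big]

% Closed-form optimal policy: the reference distribution re-weighted by the
% exponentiated reward, normalized by the partition function Z(x).
\pi^{*}(y\mid x) \;=\; \frac{1}{Z(x)}\,\pi_{\mathrm{ref}}(y\mid x)\,
  \exp\!\Big(\tfrac{1}{\beta}\,r(x,y)\Big),
\qquad
Z(x) \;=\; \sum_{y}\pi_{\mathrm{ref}}(y\mid x)\,\exp\!\Big(\tfrac{1}{\beta}\,r(x,y)\Big)
```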
[00:33:15.880 --> 00:33:17.480] So if you have a good completion, [00:33:17.480 --> 00:33:19.700] you want to put more probability mass on it, [00:33:19.700 --> 00:33:21.320] and if you have a bad completion, why? [00:33:21.320 --> 00:33:23.480] You want to put less probability mass on this. [00:33:23.480 --> 00:33:25.480] This may look familiar, it's the Boltzmann distribution [00:33:25.480 --> 00:33:26.640] that you might have seen earlier, [00:33:26.640 --> 00:33:27.760] and it's very commonly used [00:33:27.760 --> 00:33:30.420] across machine learning and physics. [00:33:30.420 --> 00:33:32.600] So, but the key takeaway here is that [00:33:32.600 --> 00:33:37.120] every reward function R will induce an optimal policy, pi R. [00:33:37.120 --> 00:33:38.900] But there's a very nice way to view this identity [00:33:38.900 --> 00:33:41.880] through another perspective where we express [00:33:41.880 --> 00:33:44.560] the reward model in terms of the policy itself. [00:33:44.560 --> 00:33:47.680] So R pi X comma Y can be written as beta log ratio [00:33:47.680 --> 00:33:52.040] of pi by pi ref, plus the beta log partition function ZX. [00:33:52.040 --> 00:33:56.240] And this really is the key where every policy pi [00:33:56.240 --> 00:34:00.620] is optimal for some induced reward model, R pi. [00:34:00.620 --> 00:34:02.680] And this really is the key to DPO, [00:34:02.680 --> 00:34:04.840] because our key idea here is that [00:34:04.840 --> 00:34:06.600] you can fit this reward model, [00:34:06.600 --> 00:34:09.680] parameterize as a beta log ratio, to the preference data, [00:34:09.680 --> 00:34:12.560] and hopefully skip the RL process altogether. [00:34:12.560 --> 00:34:15.940] But the problem is that this log partition function [00:34:15.940 --> 00:34:18.360] is basically intractable, as you have to sum over [00:34:18.360 --> 00:34:21.320] all possible completions for a given instruction. [00:34:21.320 --> 00:34:23.080] So how do we get away from this? [00:34:24.820 --> 00:34:27.300] Now fortunately for us, the reward modeling loss [00:34:27.300 --> 00:34:29.100] that we looked at, the Bradley-Terry loss, [00:34:29.100 --> 00:34:31.180] only depends on the differences in the reward. [00:34:31.180 --> 00:34:34.940] Specifically, the reward for the preferred completion, [00:34:34.940 --> 00:34:38.620] subtracting the dispreferred completion's reward from that. [00:34:38.620 --> 00:34:42.840] Now, if you look at the induced reward difference, [00:34:42.840 --> 00:34:45.620] and if you plug in the DPO parameterization here, [00:34:45.620 --> 00:34:48.500] you can see that it only ends up depending on [00:34:48.500 --> 00:34:51.380] the DPO reward for the preferred completion, [00:34:51.380 --> 00:34:55.140] and subtract the DPO reward for the dispreferred completion. [00:34:55.140 --> 00:34:57.040] Now the more important thing here is that [00:34:57.040 --> 00:34:59.780] the partition function, which only depends [00:34:59.780 --> 00:35:02.460] on the instruction X, cancels out, [00:35:02.460 --> 00:35:05.260] as it only depends on the prompt. [00:35:05.260 --> 00:35:07.500] And this really is the key part here. [00:35:07.500 --> 00:35:10.760] And if you plug in this difference of rewards [00:35:10.760 --> 00:35:14.860] in the classification loss, you get the DPO loss function. 
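Spelling out the algebra just described (a rough rendering in standard DPO notation, not a verbatim copy of the slides): the closed-form optimum, the induced reward, and the loss you get once the partition function cancels are:

```latex
% Closed-form optimum of the KL-regularized objective:
\pi_r(y \mid x) = \frac{1}{Z(x)}\,\pi_{\mathrm{ref}}(y \mid x)\,
  \exp\!\left(\tfrac{1}{\beta}\, r(x, y)\right)

% Rearranged: every policy induces a reward, up to a prompt-only term Z(x):
r(x, y) = \beta \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)

% Plugging the reward difference into the Bradley-Terry loss, Z(x) cancels:
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x, y_w, y_l)\sim\mathcal{D}}
  \left[\log \sigma\!\left(
    \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
  - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
  \right)\right]
```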
[00:35:14.860 --> 00:35:18.240] And really in its essence, it's just a classification loss [00:35:18.240 --> 00:35:20.460] with a specific reward parameterization, [00:35:20.460 --> 00:35:22.940] which will give you the optimal policy [00:35:22.940 --> 00:35:25.620] for the original RLHF objective. [00:35:25.620 --> 00:35:31.220] So to go back to what Eric presented earlier, [00:35:31.220 --> 00:35:34.020] the RLHF is typically a two-step process. [00:35:34.020 --> 00:35:35.540] You first fit a reward model, [00:35:35.540 --> 00:35:37.420] and then you do some RL on top of it. [00:35:37.420 --> 00:35:40.100] Really what we are doing here is that [00:35:40.100 --> 00:35:42.180] we choose a specific parameterization, [00:35:42.180 --> 00:35:44.180] the DPO parameterization for the reward model, [00:35:44.180 --> 00:35:45.660] and we're still fitting the reward model [00:35:45.660 --> 00:35:47.140] exactly the same way. [00:35:47.140 --> 00:35:49.060] But you get the optimal policy in process, [00:35:49.060 --> 00:35:50.380] and you don't have to do the step two [00:35:50.380 --> 00:35:51.420] at any point of time. [00:35:51.420 --> 00:35:55.140] It's pretty useful to look at the DPO loss function [00:35:55.140 --> 00:35:56.900] through its gradient as well. [00:35:56.900 --> 00:35:58.860] Just to recall, it's still a classification loss, [00:35:58.860 --> 00:36:01.200] nothing changed in the two slides. [00:36:01.200 --> 00:36:04.180] And you're trying to maximize the difference [00:36:04.180 --> 00:36:05.500] between the rewards. [00:36:05.500 --> 00:36:07.300] But the gradient is really intuitive. [00:36:07.300 --> 00:36:08.540] Specifically, what we're trying to do [00:36:08.540 --> 00:36:12.100] is increase the log probability of the chosen completion, [00:36:12.100 --> 00:36:13.860] and we're trying to reduce the log probability [00:36:13.860 --> 00:36:15.420] of the rejected completion. [00:36:15.420 --> 00:36:17.420] The important part here is that we slow down [00:36:17.420 --> 00:36:19.700] the training on the preference pairs, [00:36:19.700 --> 00:36:22.020] where the induced reward model [00:36:22.020 --> 00:36:23.420] is already pointing the right direction, [00:36:23.420 --> 00:36:25.200] so you're not overfitting to the examples [00:36:25.200 --> 00:36:26.300] over and over again. [00:36:26.300 --> 00:36:27.700] But overall, it's really intuitive, [00:36:27.700 --> 00:36:29.700] as you're just doing up on the good examples [00:36:29.700 --> 00:36:31.140] and down on the bad examples. [00:36:31.140 --> 00:36:36.540] - And finally, moving to our experimental results. [00:36:36.540 --> 00:36:38.500] The first thing we really wanted to evaluate [00:36:38.500 --> 00:36:40.840] is how good of an optimizer that is [00:36:40.840 --> 00:36:42.860] for the core objective of reward [00:36:42.860 --> 00:36:46.220] versus divergence trade-off for these language models. [00:36:46.220 --> 00:36:47.780] So we started with this synthetic experiment, [00:36:47.780 --> 00:36:51.160] where the goal is to generate positive movie reviews [00:36:51.160 --> 00:36:55.640] on this IMDB dataset with a small GPT-2 base model. [00:36:55.640 --> 00:36:57.220] We created synthetic preferences [00:36:57.220 --> 00:36:59.440] by sampling several times from the base model [00:36:59.440 --> 00:37:02.420] and using a pre-trained score classifier [00:37:02.420 --> 00:37:06.140] to construct synthetic feedback pairs. 
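Before the experiments, a concrete illustration of the loss and gradient just described: a minimal PyTorch-style sketch of the DPO loss computed from summed per-token log-probabilities. Names like policy_logps_chosen are placeholders for quantities you would compute with your own model code; this is not the authors' reference implementation.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_logps_chosen, policy_logps_rejected,
             ref_logps_chosen, ref_logps_rejected, beta=0.1):
    """Minimal DPO loss sketch.

    Each argument is a 1-D tensor of summed log-probabilities log pi(y|x)
    for a batch of (prompt, chosen, rejected) triples.
    """
    # Implied rewards: beta times the log-ratio of policy to reference.
    chosen_rewards = beta * (policy_logps_chosen - ref_logps_chosen)
    rejected_rewards = beta * (policy_logps_rejected - ref_logps_rejected)

    # Bradley-Terry classification loss on the reward difference. The sigmoid
    # factor in the gradient shrinks updates on pairs the implied reward model
    # already orders correctly, which is the "slow down training" effect
    # mentioned in the talk.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```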
[00:37:06.140 --> 00:37:08.220] Kind of immediately, the first thing we see [00:37:08.220 --> 00:37:12.220] is that DPO provides the best reward-KL trade-off. [00:37:12.220 --> 00:37:14.580] And PPO, although it improves quite a bit, [00:37:14.580 --> 00:37:17.700] it doesn't quite match that efficiency of optimization, [00:37:17.700 --> 00:37:21.780] even when we provide it with the ground-truth scoring model [00:37:21.780 --> 00:37:25.120] that generated the preference data. [00:37:25.120 --> 00:37:27.500] And in addition, other sort of algorithms [00:37:27.500 --> 00:37:31.140] that are RL-free, avoid sort of the RL modeling approach, [00:37:31.140 --> 00:37:32.940] such as just fine-tuning on the preferred answers [00:37:32.940 --> 00:37:34.980] or things like that, either don't produce [00:37:34.980 --> 00:37:38.000] the same level of improvement or are unstable. [00:37:38.000 --> 00:37:42.420] We then decided to try to scale these results up [00:37:42.420 --> 00:37:45.540] to harder, more involved problems. [00:37:45.540 --> 00:37:48.500] The first thing we did is this summarization task. [00:37:48.500 --> 00:37:52.140] The goal is to provide summarizations of some Reddit posts [00:37:52.140 --> 00:37:54.700] and dialogue tasks on the Anthropic Helpful [00:37:54.700 --> 00:37:58.800] and Harmless dataset, publicly released datasets. [00:37:58.800 --> 00:38:00.220] And kind of again, what we see there [00:38:00.220 --> 00:38:03.060] is that across the board, DPO either matches [00:38:03.060 --> 00:38:06.220] or outperforms all other baselines. [00:38:06.220 --> 00:38:08.660] And particularly, for example, in the summarization case, [00:38:08.660 --> 00:38:11.360] the PPO model is almost twice as big. [00:38:12.360 --> 00:38:16.200] So another interesting experiment that we ran recently [00:38:16.200 --> 00:38:19.240] is evaluating the generalization capabilities [00:38:19.240 --> 00:38:22.040] of the DPO policy, because essentially, [00:38:22.040 --> 00:38:25.480] the PPO-trained approaches sample a lot of additional data [00:38:25.480 --> 00:38:28.480] and have the capability to train on a lot of additional data, [00:38:28.480 --> 00:38:30.560] while DPO is only using [00:38:30.560 --> 00:38:32.720] the fully offline dataset of preferences. [00:38:32.720 --> 00:38:36.160] So what we did here is we took the summarization models [00:38:36.160 --> 00:38:38.680] that we presented in the previous slides. [00:38:38.680 --> 00:38:42.440] Those are the first two graphs on the left. [00:38:42.440 --> 00:38:44.240] Sampled at different temperatures [00:38:44.240 --> 00:38:45.960] and evaluated within distribution. [00:38:45.960 --> 00:38:47.400] As you can see, within distribution, [00:38:47.400 --> 00:38:48.840] they're quite comparable. [00:38:48.840 --> 00:38:51.600] And then we evaluated them on out-of-distribution data, [00:38:51.600 --> 00:38:53.360] particularly summarization of news: [00:38:53.360 --> 00:38:55.680] CNN and Daily Mail articles. [00:38:55.680 --> 00:38:57.580] And we do see quite a significant drop [00:38:57.580 --> 00:38:59.360] when we take these models out of distribution. [00:38:59.360 --> 00:39:02.780] But the interesting thing is that the DPO policy [00:39:02.780 --> 00:39:05.960] still generalizes just as well, or even perhaps better, [00:39:05.960 --> 00:39:07.880] than the PPO-trained policy, [00:39:07.880 --> 00:39:09.400] even though the PPO-trained policy [00:39:09.400 --> 00:39:11.960] is trained on a lot more additionally sampled data.
[00:39:11.960 --> 00:39:16.400] However, I think the strongest validation [00:39:16.400 --> 00:39:18.120] of this algorithm and its capabilities [00:39:18.120 --> 00:39:19.960] are the strong open-source models [00:39:19.960 --> 00:39:22.280] that have been trained by the community, [00:39:22.280 --> 00:39:25.060] and this is only a selection of those. [00:39:25.060 --> 00:39:27.400] There are others we couldn't fit on the slide. [00:39:27.400 --> 00:39:28.920] And if you couldn't go through all of them, [00:39:28.920 --> 00:39:31.400] you see that especially some of the recent ones [00:39:31.400 --> 00:39:33.380] do match, or sometimes even outperform, [00:39:33.380 --> 00:39:36.560] ChatGPT on some broad benchmarks. [00:39:36.560 --> 00:39:38.020] Another point to mention here is [00:39:38.020 --> 00:39:39.920] this is only within the language domain, [00:39:39.920 --> 00:39:42.680] but recently works have done this, training [00:39:42.680 --> 00:39:44.520] state-of-the-art text-to-image models [00:39:44.520 --> 00:39:48.200] with the DPO algorithm, using it for vision language models, [00:39:48.200 --> 00:39:50.420] and also using it for multi-step control as well. [00:39:50.420 --> 00:39:52.400] So this is going beyond languages. [00:39:52.400 --> 00:39:54.840] It is becoming kind of a paradigm of alignment. [00:39:54.840 --> 00:39:57.640] So in conclusion, I want to point out [00:39:57.640 --> 00:40:00.040] that kind of the DPO removes the complicated, [00:40:00.040 --> 00:40:03.460] expensive RL training loop from RLHF. [00:40:03.460 --> 00:40:06.400] It's simpler, stable, and computationally cheaper [00:40:06.400 --> 00:40:09.800] than PPO, I think by almost an order of magnitude. [00:40:09.800 --> 00:40:11.200] And most importantly, it's also principled. [00:40:11.200 --> 00:40:13.000] You're optimizing the exact same objective. [00:40:13.000 --> 00:40:14.360] It's not a hack. [00:40:14.360 --> 00:40:16.400] It's optimizing for the exact same thing. [00:40:16.400 --> 00:40:19.360] And yeah, as you've seen, others are training [00:40:19.360 --> 00:40:21.240] a lot of state-of-the-art models. [00:40:21.240 --> 00:40:22.920] We've been achieving pretty strong results, [00:40:22.920 --> 00:40:24.700] so you should do as well. [00:40:24.700 --> 00:40:26.640] If you want to learn more about it, [00:40:26.640 --> 00:40:28.720] you can come talk to us at our poster, [00:40:28.720 --> 00:40:32.840] and we have publicly open-sourced our code implementation. [00:40:32.840 --> 00:40:33.920] You can find it on GitHub, [00:40:33.920 --> 00:40:36.300] and you can check our paper on arXiv as well. [00:40:36.300 --> 00:40:37.500] Thank you very much. [00:40:37.500 --> 00:40:39.840] (audience applauding) [00:40:39.840 --> 00:40:41.980] - So DPO is interesting because it promises [00:40:41.980 --> 00:40:43.760] to be simpler than PPO. [00:40:43.760 --> 00:40:45.880] It's definitely easier and cheaper to train, [00:40:45.880 --> 00:40:48.000] and there are a bunch of models already emerging, [00:40:48.000 --> 00:40:49.560] being trained on it. [00:40:49.560 --> 00:40:51.880] The main criticism that people seem to have [00:40:51.880 --> 00:40:54.280] is that it isn't performing as well [00:40:54.280 --> 00:40:56.960] in terms of alignment or results or benchmarks [00:40:56.960 --> 00:40:58.640] as PPO-trained models.
[00:40:58.640 --> 00:41:00.120] But that still remains to be seen, [00:41:00.120 --> 00:41:03.780] whether that ease of use and cheapness of availability [00:41:03.780 --> 00:41:06.480] of data or whatever makes it so much better [00:41:06.480 --> 00:41:07.840] that it doesn't actually matter. [00:41:07.840 --> 00:41:10.360] So what happens in NeurIPS is that some papers [00:41:10.360 --> 00:41:11.560] are selected for oral sessions, [00:41:11.560 --> 00:41:14.540] and then everyone heads down to the poster hall [00:41:14.540 --> 00:41:18.180] where there's about 600 posters simultaneously presenting, [00:41:18.180 --> 00:41:20.240] including the people from the oral sessions. [00:41:20.240 --> 00:41:21.080] And this is what we did. [00:41:21.080 --> 00:41:24.240] We went down to talk to the paper authors [00:41:24.240 --> 00:41:25.920] after their oral session. [00:41:25.920 --> 00:41:28.560] So we're gonna hear them re-explain DPO in four minutes [00:41:28.560 --> 00:41:30.360] and then answer a bunch of Q&A. [00:41:30.360 --> 00:41:32.280] But you can also get a sense of how chaotic [00:41:32.280 --> 00:41:34.440] and noisy it is in that poster session. [00:41:34.440 --> 00:41:36.320] It's just a mess and I love it. [00:41:36.320 --> 00:41:39.280] - I'm talking about direct preference optimization here. [00:41:39.280 --> 00:41:41.240] RLHF is really cool. [00:41:41.240 --> 00:41:43.880] You get ChatGPT from GPT-3 using RLHF. [00:41:43.880 --> 00:41:45.120] If you've never heard of ChatGPT, [00:41:45.120 --> 00:41:45.980] you might wanna look it up. [00:41:45.980 --> 00:41:47.040] It's really important. [00:41:47.040 --> 00:41:49.960] RLHF is complicated. [00:41:49.960 --> 00:41:51.480] It's really hard. [00:41:51.480 --> 00:41:54.560] When you start with a reference data distribution, [00:41:54.560 --> 00:41:56.980] you usually have to do some kind of RL process on top of it. [00:41:56.980 --> 00:41:58.320] And RL is hard to implement [00:41:58.320 --> 00:42:00.760] because it has a lot of moving components. [00:42:00.760 --> 00:42:01.960] You have to sample the model a lot. [00:42:01.960 --> 00:42:03.400] You have to train a value function. [00:42:03.400 --> 00:42:06.520] You have to do a lot of magic trickery to get it to work. [00:42:06.520 --> 00:42:09.680] Our hope was that, can we make this simpler? [00:42:09.680 --> 00:42:13.000] And that's where we designed DPO. [00:42:13.000 --> 00:42:15.000] Just to give a brief overview of RLHF. [00:42:15.000 --> 00:42:18.600] It starts off with some distribution or some model [00:42:18.600 --> 00:42:19.600] that you have already trained, [00:42:19.600 --> 00:42:21.260] which is usually reasonably good. [00:42:21.260 --> 00:42:24.680] I'm thinking of GPT-3, which is already pretty good. [00:42:24.680 --> 00:42:26.620] Then you collect some preference data on top of it. [00:42:26.620 --> 00:42:29.880] So you have an instruction, a pair of completions, [00:42:29.880 --> 00:42:31.400] and the human labels, which one is preferred [00:42:31.400 --> 00:42:32.840] and which one is dispreferred. [00:42:32.840 --> 00:42:37.240] With this preference data, you first fit a reward model. [00:42:37.240 --> 00:42:39.160] The reward model will give you, [00:42:39.160 --> 00:42:41.040] it's basically telling you that the preferred completion [00:42:41.040 --> 00:42:42.140] should have a higher reward, [00:42:42.140 --> 00:42:45.060] and the dispreferred completion should have a lower reward. [00:42:45.060 --> 00:42:47.900] And this is a simple classification problem.
[00:42:47.900 --> 00:42:49.840] It's very straightforward. [00:42:49.840 --> 00:42:53.180] Now, given this reward model, you wanna do RL on top of it. [00:42:53.180 --> 00:42:55.760] So you wanna generate completions which are good. [00:42:55.760 --> 00:42:58.320] And the way you set it up is you maximize [00:42:58.320 --> 00:43:01.280] the expected reward under a KL constraint [00:43:01.280 --> 00:43:03.200] to the initial distribution that it started with. [00:43:03.200 --> 00:43:04.800] Now, why the KL constraint? [00:43:04.800 --> 00:43:07.120] The models can degenerate very, very quickly. [00:43:07.120 --> 00:43:12.520] And usually what you wanna do is stay close to these models [00:43:12.520 --> 00:43:13.760] so you don't degenerate, [00:43:13.760 --> 00:43:16.000] and you do not exploit the reward model. [00:43:16.000 --> 00:43:17.020] The reward models are trained [00:43:17.020 --> 00:43:18.840] on a very small amount of data, [00:43:18.840 --> 00:43:20.680] and these are very easy to exploit. [00:43:20.680 --> 00:43:22.580] So that's why this KL constraint is important. [00:43:22.580 --> 00:43:24.140] This is a traditional RLHF pipeline. [00:43:24.140 --> 00:43:26.960] This is what exactly was used for ChatGPT, [00:43:26.960 --> 00:43:28.320] initially at least. [00:43:28.320 --> 00:43:31.380] And it's very complicated to do with PPO. [00:43:31.380 --> 00:43:32.800] It's hard to get it right. [00:43:32.800 --> 00:43:36.840] Now, our contribution is the direct preference optimization. [00:43:36.840 --> 00:43:38.400] And the way this works is that [00:43:38.400 --> 00:43:40.360] it turns out for this optimization, [00:43:40.360 --> 00:43:43.200] there is an exact optimal solution. [00:43:43.200 --> 00:43:44.560] This optimal solution, if you've seen [00:43:44.560 --> 00:43:47.040] Boltzmann distribution before, very simple. [00:43:47.040 --> 00:43:48.480] You take your reference distribution, [00:43:48.480 --> 00:43:51.360] you up-weight the good things by the exponentiated reward, [00:43:51.360 --> 00:43:53.760] and you down-weight the things, [00:43:53.760 --> 00:43:54.600] which are bad. [00:43:54.600 --> 00:43:57.200] Like, it's just the reference distribution weighted [00:43:57.200 --> 00:43:59.400] by the exponentiated reward. [00:43:59.400 --> 00:44:02.920] Now, unfortunately, this is intractable. [00:44:02.920 --> 00:44:03.740] Why? [00:44:03.740 --> 00:44:05.540] Because the partition function is intractable. [00:44:05.540 --> 00:44:07.780] So you cannot actually compute this distribution. [00:44:07.780 --> 00:44:10.560] But as it will turn out, this is not gonna matter. [00:44:10.560 --> 00:44:12.560] So our main contribution is that [00:44:12.560 --> 00:44:15.120] you can actually rewrite the reward [00:44:15.120 --> 00:44:17.360] in terms of the policy itself. [00:44:17.360 --> 00:44:20.160] So simple algebra, you write the reward [00:44:20.160 --> 00:44:22.480] in terms of beta log pi over pi ref. [00:44:22.480 --> 00:44:24.760] So this is just simple algebra here. [00:44:24.760 --> 00:44:27.120] Take your time, just look at it for a second. [00:44:27.120 --> 00:44:29.080] You're just rearranging terms. [00:44:29.080 --> 00:44:30.360] But the thing is that you still have [00:44:30.360 --> 00:44:33.920] a beta log partition function, which is still intractable.
[00:44:33.920 --> 00:44:36.480] Now, the key thing is, we can fit this reward [00:44:36.480 --> 00:44:38.240] using the same classification loss [00:44:38.240 --> 00:44:41.680] that we were using earlier over here. [00:44:41.680 --> 00:44:44.160] But the nice thing is, it depends upon the difference [00:44:44.160 --> 00:44:46.000] between the reward for the good completion [00:44:46.000 --> 00:44:47.680] and the bad completion. [00:44:47.680 --> 00:44:50.720] And the partition function actually cancels out. [00:44:50.720 --> 00:44:52.000] If you look at the partition function, [00:44:52.000 --> 00:44:53.840] it only depends on the instruction. [00:44:53.840 --> 00:44:58.240] So it only ends up depending on this quantity. [00:44:58.240 --> 00:45:00.640] And this is exactly how you get the DPO loss. [00:45:00.640 --> 00:45:03.120] You're plugging in this implied reward function [00:45:03.120 --> 00:45:04.960] into the classification loss, [00:45:04.960 --> 00:45:06.360] and you get the DPO classification, [00:45:06.360 --> 00:45:08.360] which is directly in terms of your policy [00:45:08.360 --> 00:45:10.320] that is being fine-tuned. [00:45:10.320 --> 00:45:12.800] So you no longer need to do an explicit reward model [00:45:12.800 --> 00:45:14.520] where you're learning a different reward model. [00:45:14.520 --> 00:45:17.480] You do not have to do any RL optimization after that. [00:45:17.480 --> 00:45:19.360] What you're doing is exactly, [00:45:19.360 --> 00:45:20.400] you're fitting this reward model, [00:45:20.400 --> 00:45:22.400] and you immediately get the optimal policy [00:45:22.400 --> 00:45:25.040] for that reward model without doing any RL. [00:45:25.040 --> 00:45:27.760] And that's like the main pitch for DPO. [00:45:27.760 --> 00:45:30.320] Any questions, anything I can explain further? [00:45:30.320 --> 00:45:35.320] - So that means you don't have to learn any reward function. [00:45:35.320 --> 00:45:37.280] - You don't have to? [00:45:37.280 --> 00:45:40.160] But the policy already implies a reward. [00:45:40.160 --> 00:45:42.000] Yes, exactly. [00:45:42.000 --> 00:45:42.840] Does that make sense? [00:45:42.840 --> 00:45:44.000] - I'll tell you the reason. [00:45:44.000 --> 00:45:44.840] - Skeptical. [00:45:44.840 --> 00:45:45.680] - How did I do? [00:45:45.680 --> 00:45:46.520] - Yes, yes. [00:45:46.520 --> 00:45:47.360] - Not the-- [00:45:47.360 --> 00:45:48.800] - This is not the actual, but this is-- [00:45:48.800 --> 00:45:49.760] - Specific reward model. [00:45:49.760 --> 00:45:50.600] - Yes. [00:45:50.600 --> 00:45:52.560] - What about the data collection aspect of it? [00:45:52.560 --> 00:45:53.400] - Sorry? [00:45:53.400 --> 00:45:55.600] - What about the data collection aspect of RLHF? [00:45:55.600 --> 00:45:56.640] - That's a great question. [00:45:56.640 --> 00:46:00.640] So, PPO usually samples more completions online, [00:46:00.640 --> 00:46:02.240] and you don't have to do any of that. [00:46:02.240 --> 00:46:04.120] You only have to sample the preference data set [00:46:04.120 --> 00:46:05.920] in the beginning, which we use for-- [00:46:05.920 --> 00:46:08.040] - The same way we do PPO and RLHF. [00:46:08.040 --> 00:46:09.920] - But how do you know that your preference data set [00:46:09.920 --> 00:46:11.680] is as good as-- [00:46:11.680 --> 00:46:13.400] - We use the exact same preference data set [00:46:13.400 --> 00:46:15.800] for RLHF and for DPO. [00:46:15.800 --> 00:46:18.200] - It's like a mathematical shortcut. 
[00:46:18.200 --> 00:46:20.720] - Yeah, I'm sorry, I'm sorry. [00:46:20.720 --> 00:46:21.560] - Like the-- [00:46:21.560 --> 00:46:22.800] - The new loss function. [00:46:22.800 --> 00:46:25.160] - From the fact, like, you train this model [00:46:25.160 --> 00:46:26.120] on some data distribution-- [00:46:26.120 --> 00:46:26.960] - Yes. [00:46:26.960 --> 00:46:27.880] - But when you explore-- [00:46:27.880 --> 00:46:28.720] - Yes. [00:46:28.720 --> 00:46:29.560] - It might go out of distribution-- [00:46:29.560 --> 00:46:30.400] - Yes, yes, yes. [00:46:30.400 --> 00:46:31.600] - Which kind of, like, limits the policy. [00:46:31.600 --> 00:46:32.440] - Yes, exactly. [00:46:32.440 --> 00:46:34.520] - You see, that is a major reason the drop is-- [00:46:34.520 --> 00:46:35.800] - In general, PPO also has, like, [00:46:35.800 --> 00:46:37.040] a high variance estimator, [00:46:37.040 --> 00:46:38.680] so the optimization is never perfect, [00:46:38.680 --> 00:46:40.480] whereas with DPO, you know for a fact [00:46:40.480 --> 00:46:42.480] that it's an optimal policy. [00:46:42.480 --> 00:46:44.760] So, like, it's very, very similar, [00:46:44.760 --> 00:46:46.680] like, you know for a fact that it's optimal. [00:46:46.680 --> 00:46:48.560] But in general, like, if you have [00:46:48.560 --> 00:46:51.320] a very well-fine-tuned PPO pipeline, [00:46:51.320 --> 00:46:54.440] it will usually work reasonably similarly. [00:46:54.440 --> 00:46:57.160] But, yeah, you don't have to do any of that. [00:46:57.160 --> 00:46:59.840] - Yeah, so, essentially, one of the things that-- [00:46:59.840 --> 00:47:00.680] - What would you-- [00:47:00.680 --> 00:47:01.520] - Assumption. [00:47:01.520 --> 00:47:02.360] - Okay, so-- [00:47:02.360 --> 00:47:04.080] - This is not an assumption. [00:47:04.080 --> 00:47:06.600] This is the actual solution, this is not an assumption. [00:47:06.600 --> 00:47:07.440] - Right. [00:47:07.440 --> 00:47:08.280] - Yeah. [00:47:08.280 --> 00:47:09.120] - It's generic enough-- [00:47:09.120 --> 00:47:10.440] - In terms of mathematical form? [00:47:10.440 --> 00:47:11.280] - Yeah. [00:47:11.280 --> 00:47:14.280] - But I was wondering because this term, rewards, [00:47:14.280 --> 00:47:16.400] for example, does it match the definition [00:47:16.400 --> 00:47:17.640] of reward, because you could write [00:47:17.640 --> 00:47:20.000] any exponential function here. [00:47:20.000 --> 00:47:20.840] - Yeah. [00:47:20.840 --> 00:47:21.680] - And it's been called reward, [00:47:21.680 --> 00:47:24.960] but does it match, like, the reward definition? [00:47:24.960 --> 00:47:26.920] - In this optimal solution, you assume [00:47:26.920 --> 00:47:29.040] there's a reward function that has been given to you. [00:47:29.040 --> 00:47:29.880] - Oh, I see. [00:47:29.880 --> 00:47:30.720] - And, yeah. [00:47:30.720 --> 00:47:31.920] - It's a sequence of actions that's asked. [00:47:31.920 --> 00:47:32.760] - Yeah, yeah, yeah. [00:47:32.760 --> 00:47:35.480] - Some constant times the log ratio of some-- [00:47:35.480 --> 00:47:38.480] - And, I mean, overall, if I look at the experiments, [00:47:38.480 --> 00:47:41.320] let's look at the real-world data sets, [00:47:41.320 --> 00:47:43.360] like, I mean, we try out, like, summarizations, [00:47:43.360 --> 00:47:45.160] I try out, like, single-term dialogue, [00:47:45.160 --> 00:47:46.320] and it all works great. 
[00:47:46.320 --> 00:47:50.200] You never had to do, like, any online exploration [00:47:50.200 --> 00:47:52.760] of any form, and, like, DPO relatively works [00:47:52.760 --> 00:47:54.960] better than PPO, or very similarly to it. [00:47:54.960 --> 00:47:55.800] I think-- [00:47:55.800 --> 00:47:57.320] - Can I ask, what's the methodology? [00:47:57.320 --> 00:48:00.480] You take a base model, and you fine-tune it with DPO? [00:48:00.480 --> 00:48:03.560] - So we take the same base model. [00:48:03.560 --> 00:48:04.400] - Yeah. [00:48:04.400 --> 00:48:05.520] - We have the same preference data set. [00:48:05.520 --> 00:48:06.360] - Yeah. [00:48:06.360 --> 00:48:07.920] - First, we fit a reward model for PPO, [00:48:07.920 --> 00:48:09.240] and then you do RL for it. [00:48:09.240 --> 00:48:10.760] - Okay, so completely comparable. [00:48:10.760 --> 00:48:12.000] - Yes. [00:48:12.000 --> 00:48:14.400] In general, like, I mean, we tried to, like, [00:48:14.400 --> 00:48:17.240] reuse, like, people's already pre-trained models [00:48:17.240 --> 00:48:19.600] for RLHF, but we looked at their pipeline, [00:48:19.600 --> 00:48:22.160] and it was exactly the same, because if we do it, [00:48:22.160 --> 00:48:23.840] like, there's always a case that it's possible [00:48:23.840 --> 00:48:25.560] that we didn't tune it well enough. [00:48:25.560 --> 00:48:26.920] So, like, we tried to, like, take models [00:48:26.920 --> 00:48:29.480] that are trained using RLHF and try to compare [00:48:29.480 --> 00:48:32.440] to them directly, but they're trained on the same data sets. [00:48:32.440 --> 00:48:36.680] Very strong models have been trained using DPO. [00:48:36.680 --> 00:48:38.000] They're already being used. [00:48:38.000 --> 00:48:39.760] - Yeah, Zephyr is the one I know about. [00:48:39.760 --> 00:48:42.320] - Tulu, Mixtral models, if you ever look at it, [00:48:42.320 --> 00:48:44.320] were trained using DPO as well. [00:48:44.320 --> 00:48:45.560] - Oh, that's the Mixtral Instruct? [00:48:45.560 --> 00:48:46.400] - Yes. [00:48:46.400 --> 00:48:47.240] - Okay. [00:48:47.240 --> 00:48:49.120] - They were trained using DPO as well, so if you guys-- [00:48:49.120 --> 00:48:50.560] - That came out, like, very recently. [00:48:50.560 --> 00:48:52.000] - Yes, that's why it's not on the poster, [00:48:52.000 --> 00:48:55.240] but, like, I mean, you guys, if you're thinking [00:48:55.240 --> 00:48:57.320] of fine-tuning using preferences, [00:48:57.320 --> 00:48:59.280] you should try to use DPO. [00:48:59.280 --> 00:49:01.040] - How much is the efficiency gain [00:49:01.040 --> 00:49:03.440] compared to a PPO process? [00:49:03.440 --> 00:49:05.800] - A lot, because you only have to do one step. [00:49:05.800 --> 00:49:06.640] - Yeah. [00:49:06.640 --> 00:49:08.760] - It uses the same set of preferences. [00:49:08.760 --> 00:49:10.880] - It's just a little one thing. [00:49:10.880 --> 00:49:12.720] - So, like, basically, no trade-offs? [00:49:12.720 --> 00:49:15.240] - I think more research-- [00:49:15.240 --> 00:49:17.280] - I'm looking for trade-offs, I cannot find any. [00:49:17.280 --> 00:49:18.920] - More research needs to be done. [00:49:18.920 --> 00:49:20.200] There are arguments to be made [00:49:20.200 --> 00:49:22.520] that PPO might do better in some cases, [00:49:22.520 --> 00:49:25.960] but it's unclear, like, we haven't personally [00:49:25.960 --> 00:49:27.280] seen any evidence yet. [00:49:27.280 --> 00:49:28.920] - I see, I see.
[00:49:28.920 --> 00:49:30.040] Sorry, one more question before I-- [00:49:30.040 --> 00:49:31.400] - Go for it, yeah. [00:49:31.400 --> 00:49:33.920] - I noticed Chelsea Finn's a co-author. [00:49:33.920 --> 00:49:37.200] What guidance has she given, or, I'm curious-- [00:49:37.200 --> 00:49:39.720] - I mean, look, we're all in her lab. [00:49:39.720 --> 00:49:40.920] She's the one who selected us. [00:49:40.920 --> 00:49:42.680] She's the one who's providing the infrastructure. [00:49:42.680 --> 00:49:43.520] - Yeah. [00:49:43.520 --> 00:49:45.920] - Like, I mean, none of this would be possible without her. [00:49:45.920 --> 00:49:48.920] - I'm just curious, like, is there any interesting stories, [00:49:48.920 --> 00:49:50.880] any, like-- [00:49:50.880 --> 00:49:51.720] - Of course, yeah. [00:49:51.720 --> 00:49:53.040] - Good advice that she gave that, like, [00:49:53.040 --> 00:49:55.720] really inspired you, that you want to pass on to others? [00:49:55.720 --> 00:49:57.800] - I think when we started discussing the idea with her, [00:49:57.800 --> 00:50:00.800] she was very insistent that you should try to push this, [00:50:00.800 --> 00:50:04.960] because this is a nice idea, but if you sit on it, [00:50:04.960 --> 00:50:07.080] somebody might do it, or it might fade out of relevance. [00:50:07.080 --> 00:50:09.400] So this paper came about in three weeks [00:50:09.400 --> 00:50:11.320] before the NeurIPS deadline. [00:50:11.320 --> 00:50:13.640] So we had to push really hard. [00:50:13.640 --> 00:50:15.760] - And how did you come up with the idea, you said? [00:50:15.760 --> 00:50:17.840] - I mean, we were looking at this kind of equation [00:50:17.840 --> 00:50:20.200] before Rafael did a bit of algebra, and say, [00:50:20.200 --> 00:50:22.840] oh, maybe we can just, like, completely skip the RL part, [00:50:22.840 --> 00:50:24.360] if we, like, look at this thing. [00:50:24.360 --> 00:50:26.440] So, like, I mean, we were playing around. [00:50:26.440 --> 00:50:28.280] Generally speaking, like, [00:50:28.280 --> 00:50:30.280] there's a reward estimation step. [00:50:30.280 --> 00:50:33.360] Whenever you're learning three things in a sequence, [00:50:33.360 --> 00:50:35.760] if you can statistically remove one of the steps-- [00:50:35.760 --> 00:50:36.600] - Yeah, you gain a lot. [00:50:36.600 --> 00:50:40.120] - Yeah, so that's where we, the motivation usually comes from. [00:50:40.120 --> 00:50:42.080] - Has John Shulman commented on this? [00:50:42.080 --> 00:50:43.880] - Yes. [00:50:43.880 --> 00:50:45.560] - What did he say? [00:50:45.560 --> 00:50:48.080] - I mean, he tried it, he said it works, [00:50:48.080 --> 00:50:50.080] but there's some questions about, like, [00:50:50.080 --> 00:50:51.760] they might be treating their reward models [00:50:51.760 --> 00:50:54.520] on more than binary pairwise preferences, [00:50:54.520 --> 00:50:57.120] so, like, it's not immediately clear [00:50:57.120 --> 00:50:58.840] how to extend that using DPO. [00:50:58.840 --> 00:51:00.520] - Like multiple choice? [00:51:00.520 --> 00:51:02.240] - Unclear, they obviously did not tell me, [00:51:02.240 --> 00:51:03.280] like, I mean, what they're doing, [00:51:03.280 --> 00:51:04.520] but, like, they're training on more [00:51:04.520 --> 00:51:05.640] than just pairwise preferences, [00:51:05.640 --> 00:51:07.200] and they might still want to do, like, RL, RL. [00:51:07.200 --> 00:51:09.640] - You can decompose most things into pairwise. 
[00:51:09.640 --> 00:51:11.640] - Yeah, yeah, that's kind of what I assume, [00:51:11.640 --> 00:51:13.840] but, like, I don't know what exactly they're doing. [00:51:13.840 --> 00:51:15.800] So there's a situation where they might be conditioning [00:51:15.800 --> 00:51:17.280] their reward model on something more [00:51:17.280 --> 00:51:19.040] than what your policy is conditioned on. [00:51:19.040 --> 00:51:21.720] - That means my wrap of Y and X is a constant. [00:51:21.720 --> 00:51:22.560] - Yeah. [00:51:22.560 --> 00:51:23.400] - And so, this-- [00:51:23.400 --> 00:51:24.240] - It's all I got, thank you. [00:51:24.240 --> 00:51:25.080] - Is zero. [00:51:25.080 --> 00:51:27.680] - The other best paper runner-up that we'll talk about [00:51:27.680 --> 00:51:30.120] is scaling data-constrained language models. [00:51:30.120 --> 00:51:32.040] In other words, the datablations paper. [00:51:32.040 --> 00:51:33.920] And this is a scaling laws paper, [00:51:33.920 --> 00:51:36.680] kind of in the vein of the chinchilla paper, [00:51:36.680 --> 00:51:39.320] but done with a different assumption in mind. [00:51:39.320 --> 00:51:41.160] Instead of holding compute constant [00:51:41.160 --> 00:51:43.120] or holding parameter count constant, [00:51:43.120 --> 00:51:46.000] here we are running into the real-world problem [00:51:46.000 --> 00:51:47.520] of data constraint. [00:51:47.520 --> 00:51:49.880] So given that you have a fixed amount of data, [00:51:49.880 --> 00:51:52.080] what should you do to pre-train your models? [00:51:52.080 --> 00:51:53.440] This kind of paper tends to be [00:51:53.440 --> 00:51:54.920] a very expensive paper to write, [00:51:54.920 --> 00:51:57.160] just because you have to do so many ablations. [00:51:57.160 --> 00:51:59.760] Here it's notable that HuggingFace has created this [00:51:59.760 --> 00:52:01.840] and open-sourced it, both models and datasets. [00:52:01.840 --> 00:52:03.680] So kudos to HuggingFace. [00:52:03.680 --> 00:52:05.360] - Hi, I'm Niklas, and I'm presenting [00:52:05.360 --> 00:52:07.520] scaling data-constrained language models. [00:52:07.520 --> 00:52:11.560] The premise for this work is that we are data-constrained. [00:52:11.560 --> 00:52:13.920] Here's a plot from prior work that estimates [00:52:13.920 --> 00:52:16.800] that given their definition of high-quality language data, [00:52:16.800 --> 00:52:18.700] we're going to be exhausted next year. [00:52:18.700 --> 00:52:21.880] And what they mean with high-quality language data [00:52:21.880 --> 00:52:23.880] is data such as papers and books. [00:52:23.880 --> 00:52:26.960] There's other sources, like code, [00:52:26.960 --> 00:52:29.280] however it's unclear how useful it actually is [00:52:29.280 --> 00:52:30.600] for large language models. [00:52:31.720 --> 00:52:33.120] And for low-resource languages, [00:52:33.120 --> 00:52:35.120] we are already hardcore data-constrained. [00:52:35.120 --> 00:52:42.040] The first solution we investigate is simply repeating data. [00:52:42.040 --> 00:52:43.520] It's important to mention here that, [00:52:43.520 --> 00:52:45.920] while it's pretty common to train for multiple epochs [00:52:45.920 --> 00:52:47.680] in most machine learning problems, [00:52:47.680 --> 00:52:50.360] for large language models, this has been very uncommon. [00:52:50.360 --> 00:52:52.840] In GPT-3, they write that data are sampled [00:52:52.840 --> 00:52:54.120] without replacement. 
[00:52:54.120 --> 00:52:56.480] In Palm, they say that they explicitly avoid [00:52:56.480 --> 00:52:58.480] repeating data in any subcomponent. [00:52:58.480 --> 00:53:00.560] And there was other work explicitly recommending [00:53:00.560 --> 00:53:02.400] against repeating any data [00:53:02.400 --> 00:53:05.040] when training large language models. [00:53:05.040 --> 00:53:07.160] So we ask, is it really that bad? [00:53:07.160 --> 00:53:12.640] To answer this question, we have three different setups. [00:53:12.640 --> 00:53:17.440] We start by simply training for a single epoch. [00:53:17.440 --> 00:53:19.280] Here, this is your usual training graph [00:53:19.280 --> 00:53:21.480] where we have the validation loss on the y-axis [00:53:21.480 --> 00:53:23.480] and the training tokens on the x-axis. [00:53:23.480 --> 00:53:25.840] And for all of those setups, there's nothing special here. [00:53:25.840 --> 00:53:28.380] Loss improves as we increase training. [00:53:28.380 --> 00:53:31.080] Now, what happens if we train for two epochs? [00:53:31.080 --> 00:53:33.320] Notably, the performance is around the same. [00:53:33.320 --> 00:53:35.600] So here, only half of the data is unique [00:53:35.600 --> 00:53:37.240] and it has to be repeated twice. [00:53:37.240 --> 00:53:40.360] So for the setup on the left, 28 billion tokens are unique [00:53:40.360 --> 00:53:42.160] and they're repeated for two epochs. [00:53:42.160 --> 00:53:46.560] Three, four, and it's still pretty similar. [00:53:46.560 --> 00:53:48.840] However, eventually, it starts to diverge. [00:53:48.840 --> 00:53:50.400] So we shouldn't train for too many epochs. [00:53:50.400 --> 00:53:53.880] At 44 epochs, literally just 1/44 of the data is unique [00:53:53.880 --> 00:53:55.320] and repeating 44 times. [00:53:55.320 --> 00:53:56.920] So that's like one billion tokens, [00:53:56.920 --> 00:53:59.720] one billion unique tokens for the setup on the left. [00:53:59.720 --> 00:54:01.760] And that obviously isn't very good. [00:54:01.760 --> 00:54:04.720] However, for a few repeats, performance is very similar, [00:54:04.720 --> 00:54:06.920] suggesting that we can scale a lot further [00:54:06.920 --> 00:54:08.560] with existing data constraints [00:54:08.560 --> 00:54:10.920] by simply repeating for large language models. [00:54:10.920 --> 00:54:14.520] This naturally leads to the question, [00:54:14.520 --> 00:54:16.040] how should we allocate compute [00:54:16.040 --> 00:54:18.440] when we are in that repeated regime? [00:54:18.440 --> 00:54:20.080] A quick reminder from last year, [00:54:20.080 --> 00:54:22.960] Chinchilla told us that when we're not repeating data, [00:54:22.960 --> 00:54:24.440] so in the single epoch regime, [00:54:24.440 --> 00:54:27.720] we should scale model size and training data equally [00:54:27.720 --> 00:54:28.860] in equal proportions. [00:54:28.860 --> 00:54:32.560] How does it look like when we're repeating? [00:54:32.560 --> 00:54:33.600] To investigate this, [00:54:33.600 --> 00:54:36.240] we train on 100 million unique tokens [00:54:36.240 --> 00:54:37.880] and vary the model size [00:54:37.880 --> 00:54:40.040] and the number of epochs over those tokens. [00:54:40.040 --> 00:54:43.800] Each model is depicted as one of those dots. [00:54:43.800 --> 00:54:45.440] And as we go towards the upper right, [00:54:45.440 --> 00:54:48.040] so more parameters and more epochs, [00:54:48.040 --> 00:54:50.460] loss improves as indicated by the contours. 
[00:54:51.720 --> 00:54:53.840] We put forth scaling equations [00:54:53.840 --> 00:54:56.540] to exactly predict this change in loss [00:54:56.540 --> 00:54:57.880] and how you should allocate [00:54:57.880 --> 00:54:59.880] when you're in that repeated regime. [00:54:59.880 --> 00:55:01.380] They're depicted on the right. [00:55:01.380 --> 00:55:04.960] Now, if we add in the efficient frontiers, [00:55:04.960 --> 00:55:07.120] the Chinchilla scaling laws efficient frontier [00:55:07.120 --> 00:55:09.400] extrapolated to multiple epochs [00:55:09.400 --> 00:55:11.120] corresponds to the dashed line. [00:55:11.120 --> 00:55:13.540] So here's just an equal scaling of parameters [00:55:13.540 --> 00:55:15.400] and an equal scaling of epochs. [00:55:15.400 --> 00:55:19.680] However, our fit suggests that data should be scaled faster [00:55:19.680 --> 00:55:21.080] when we're in that repeated regime. [00:55:21.080 --> 00:55:24.520] And this is seen by the line branching off below [00:55:24.520 --> 00:55:26.200] and eventually just fades away [00:55:26.200 --> 00:55:27.160] because at some point, [00:55:27.160 --> 00:55:29.240] you can't get more value out of your data, [00:55:29.240 --> 00:55:31.800] especially with just 100 million tokens. [00:55:31.800 --> 00:55:33.080] At some point, you're just, yeah, [00:55:33.080 --> 00:55:35.240] running out of value in those few tokens. [00:55:35.240 --> 00:55:42.060] Now we test our predictions at scale. [00:55:42.060 --> 00:55:43.400] Here we have two models, [00:55:43.400 --> 00:55:46.180] one allocated according to Chinchilla scaling laws [00:55:46.180 --> 00:55:49.200] and one allocated according to data-constrained scaling. [00:55:49.200 --> 00:55:51.000] The one on the top is Chinchilla [00:55:51.000 --> 00:55:52.600] and the one on the bottom, [00:55:52.600 --> 00:55:56.360] indicated by the red star, is our allocated model. [00:55:56.360 --> 00:55:58.200] They both have the same number of flops [00:55:58.200 --> 00:56:00.580] and the same data budget of 25 billion tokens. [00:56:00.580 --> 00:56:04.920] And we see that by training with fewer parameters [00:56:04.920 --> 00:56:08.640] for more epochs, so 6.3 and 9.7 epochs, [00:56:08.640 --> 00:56:12.640] or 242 billion tokens, we get a better loss. [00:56:12.640 --> 00:56:15.660] But not only loss, we also test this [00:56:15.660 --> 00:56:17.460] in terms of downstream performance [00:56:17.460 --> 00:56:19.360] and get better downstream performance. [00:56:20.340 --> 00:56:22.640] As indicated by the column towards the right. [00:56:22.640 --> 00:56:27.260] This was repeating, and now we're going to look [00:56:27.260 --> 00:56:30.300] at complementary strategies to solve data constraints. [00:56:30.300 --> 00:56:36.180] One intuitive strategy is making use [00:56:36.180 --> 00:56:38.000] of that code data that we saw earlier. [00:56:38.000 --> 00:56:40.300] So can we simply fill up the missing data [00:56:40.300 --> 00:56:41.420] with code from GitHub? [00:56:41.420 --> 00:56:47.540] In addition, we evaluate filtering strategies. [00:56:47.540 --> 00:56:50.540] Specifically, we look at fuzzy deduplication [00:56:50.540 --> 00:56:51.860] and perplexity filtering. [00:56:51.860 --> 00:56:54.540] The idea here is, can we use a quality filter [00:56:54.540 --> 00:56:56.700] and then repeat to get better performance [00:56:56.700 --> 00:56:58.260] than with the initial data set? [00:56:58.260 --> 00:57:03.960] Here are the results.
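A brief aside before the results: as I understand the paper, the "scaling equations" mentioned above work by replacing the raw token count in a Chinchilla-style loss with an effective count that decays exponentially in the number of repetitions (parameters are treated analogously); roughly, with U_D unique tokens, R_D = D/U_D - 1 extra epochs, and a fitted decay constant R*_D:

```latex
% Effective data when U_D unique tokens are repeated for R_D extra epochs:
D' = U_D + U_D\, R^{*}_{D} \left(1 - e^{-R_D / R^{*}_{D}}\right)

% D' (and an analogous N') then stand in for D and N in a Chinchilla-style fit:
L(N, D) \approx E + \frac{A}{(N')^{\alpha}} + \frac{B}{(D')^{\beta}}
```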
[00:57:03.960 --> 00:57:06.580] On the y-axis, we have the average performance [00:57:06.580 --> 00:57:09.360] across 19 natural language tasks. [00:57:09.360 --> 00:57:11.340] On the x-axis is the data budget. [00:57:11.340 --> 00:57:14.940] So towards the left, we have 100% available data, [00:57:14.940 --> 00:57:17.140] so we don't need to use any of those strategies. [00:57:17.140 --> 00:57:18.540] But as we go to the right, [00:57:18.540 --> 00:57:20.260] our data budget is smaller and smaller, [00:57:20.260 --> 00:57:21.540] and we need to repeat data [00:57:21.540 --> 00:57:23.300] or fill the missing data with code. [00:57:23.300 --> 00:57:25.860] Starting with the purple line, [00:57:25.860 --> 00:57:27.660] we can confirm our findings from earlier [00:57:27.660 --> 00:57:29.400] that also in terms of downstream performance, [00:57:29.400 --> 00:57:31.940] roughly four epochs seems like a good trade-off. [00:57:31.940 --> 00:57:34.900] So at 25% data budget, we have to repeat four times, [00:57:34.900 --> 00:57:36.400] corresponding to four epochs. [00:57:36.400 --> 00:57:39.700] And then eventually, if you train for too many epochs, [00:57:39.700 --> 00:57:40.980] it drops quite a bit. [00:57:40.980 --> 00:57:43.700] So you have to be careful with repeating. [00:57:43.700 --> 00:57:45.860] The red line corresponds to filling missing data [00:57:45.860 --> 00:57:46.940] with Python code. [00:57:46.940 --> 00:57:49.460] Similar to the repeating line, [00:57:49.460 --> 00:57:51.780] we see that we can make up [00:57:51.780 --> 00:57:53.460] for a lot of natural language data with code [00:57:53.460 --> 00:57:55.620] without a drop in natural language performance. [00:57:55.620 --> 00:57:57.260] So these are all natural language tasks, [00:57:57.260 --> 00:58:00.020] and it seems like coding data is helpful for some of them. [00:58:00.020 --> 00:58:02.620] We even see spikes on some of these tasks [00:58:02.620 --> 00:58:03.920] as soon as code is added. [00:58:03.920 --> 00:58:07.660] Finally, we investigate the filtering strategies. [00:58:07.660 --> 00:58:10.260] We find that quality filtering, then repeating, [00:58:10.260 --> 00:58:12.740] can be much better than the data set to start with. [00:58:12.740 --> 00:58:14.500] So here, the yellow star at the top [00:58:14.500 --> 00:58:16.420] corresponds to perplexity filtering, [00:58:16.420 --> 00:58:18.580] and then repeating for two epochs. [00:58:18.580 --> 00:58:22.580] The orange star corresponds to fuzzy deduplication [00:58:22.580 --> 00:58:25.100] towards the right, and we find that you have to be careful [00:58:25.100 --> 00:58:26.420] with too much deduplication, [00:58:26.420 --> 00:58:28.460] because it can lead to a worse model [00:58:28.460 --> 00:58:30.060] by limiting your available data. [00:58:30.060 --> 00:58:40.540] Now I'll go through the takeaways. [00:58:40.540 --> 00:58:43.180] The first takeaway is that repeating data [00:58:43.180 --> 00:58:44.180] is generally fine. [00:58:44.700 --> 00:58:46.820] For many setups, roughly four epochs [00:58:46.820 --> 00:58:49.020] seems to provide a good trade-off. [00:58:49.020 --> 00:58:50.500] However, there are diminishing returns, [00:58:50.500 --> 00:58:53.100] and you have to be careful with too many epochs. [00:58:53.100 --> 00:58:55.460] Next, adding code data is fine, [00:58:55.460 --> 00:58:58.140] even if you're only interested in natural language tasks. [00:58:58.140 --> 00:59:01.420] We find that 50% provides a good trade-off for most setups. 
[00:59:01.420 --> 00:59:04.780] Finally, quality filtering plus repeating [00:59:04.780 --> 00:59:07.380] can be a good strategy, and is often much better [00:59:07.380 --> 00:59:09.140] than the data set you started with, [00:59:09.140 --> 00:59:12.340] because the penalty from repeating is often much smaller [00:59:12.340 --> 00:59:15.780] than the additional gain you can get from quality filtering. [00:59:15.780 --> 00:59:19.260] And finally, I wanted to finish off with some other work [00:59:19.260 --> 00:59:21.140] that has made use of these findings [00:59:21.140 --> 00:59:22.620] in their large language model training. [00:59:22.620 --> 00:59:24.860] So at the top, we have FinGPT, [00:59:24.860 --> 00:59:26.780] a large language model for Finnish, [00:59:26.780 --> 00:59:29.100] where they only had 38 billion unique tokens, [00:59:29.100 --> 00:59:31.180] and they had to repeat them for eight epochs [00:59:31.180 --> 00:59:32.540] in order to be able to train [00:59:32.540 --> 00:59:33.820] a reasonable large language model [00:59:33.820 --> 00:59:35.740] with 13 billion parameters. [00:59:35.740 --> 00:59:36.980] And there are several more [00:59:36.980 --> 00:59:39.700] that have made use of these findings. [00:59:39.860 --> 00:59:43.260] (audience applauding) [00:59:43.260 --> 00:59:45.540] The finding that training up to four epochs [00:59:45.540 --> 00:59:47.980] is almost as good as getting new data [00:59:47.980 --> 00:59:51.420] is pretty surprising, and actually directly counters [00:59:51.420 --> 00:59:54.340] a very famous paper called "One Epoch is All You Need." [00:59:54.340 --> 00:59:56.620] I actually read it to Aran Komatsuzaki [00:59:56.620 --> 00:59:58.540] at the decibel party. [00:59:58.540 --> 01:00:00.820] And it's just surprising at this stage in ML [01:00:00.820 --> 01:00:03.260] that we still don't know some very basic questions [01:00:03.260 --> 01:00:06.780] around how many epochs we should train on a data set. [01:00:06.780 --> 01:00:08.460] I mean, I still think that we are [01:00:08.460 --> 01:00:10.660] surprisingly sample-efficient. [01:00:10.660 --> 01:00:13.220] The consensus is now between one and four epochs, [01:00:13.220 --> 01:00:15.740] sometimes, in some cases, maybe up to eight. [01:00:15.740 --> 01:00:16.860] But more importantly than that, [01:00:16.860 --> 01:00:18.180] I think this work is notable [01:00:18.180 --> 01:00:20.140] because it is the best example [01:00:20.140 --> 01:00:22.060] of what open-source AI research should look like, [01:00:22.060 --> 01:00:23.660] and of course, it's from Hugging Face. [01:00:23.660 --> 01:00:25.180] If you go to the GitHub repo, [01:00:25.180 --> 01:00:27.020] you can see not only their papers, [01:00:27.020 --> 01:00:30.020] but also very, very well-documented code [01:00:30.020 --> 01:00:31.580] showing exactly what they did [01:00:31.580 --> 01:00:33.020] and how they got their results, [01:00:33.020 --> 01:00:34.980] including the data set filtering. [01:00:34.980 --> 01:00:37.500] So just exemplary work of open-source AI, [01:00:37.500 --> 01:00:40.260] and no surprise that they won one of the best paper awards. [01:00:40.260 --> 01:00:42.580] However, I did not manage to catch up with them [01:00:42.580 --> 01:00:45.100] for a post-presentation interview, [01:00:45.100 --> 01:00:47.780] but I did go straight to the next session [01:00:47.780 --> 01:00:49.900] on QLoRA with Tim Dettmers. [01:00:49.900 --> 01:00:50.980] - I'm Tim.
[01:00:50.980 --> 01:00:52.300] Today, I present QLORA, [01:00:52.300 --> 01:00:55.260] Efficient Fine-Tuning of Quantized Large Language Models. [01:00:55.260 --> 01:00:58.260] Language models have gotten a lot bigger [01:00:58.260 --> 01:00:59.780] and a lot more powerful, [01:00:59.780 --> 01:01:01.180] but they have become so big [01:01:01.180 --> 01:01:02.580] that it's actually quite difficult [01:01:02.580 --> 01:01:04.260] if you take a pre-trained model [01:01:04.260 --> 01:01:05.500] and you want to fine-tune it [01:01:05.500 --> 01:01:07.620] as sort of a normal researcher. [01:01:07.620 --> 01:01:09.900] Often, you need now a big GPU server, [01:01:09.900 --> 01:01:11.900] and most researchers don't have that. [01:01:11.900 --> 01:01:14.060] So with QLORA, what we worked on [01:01:14.060 --> 01:01:16.380] is reducing the memory requirements [01:01:16.380 --> 01:01:19.540] so that everybody can fine-tune large language models. [01:01:19.540 --> 01:01:21.060] The main contribution of QLORA [01:01:21.060 --> 01:01:24.260] is we compress neural networks to 4-bit, [01:01:24.260 --> 01:01:27.860] and we developed a new data type, 4-bit normal float, [01:01:27.860 --> 01:01:30.940] that can replicate 16-bit performance [01:01:30.940 --> 01:01:33.540] even though we compress the neural network to 4-bit. [01:01:34.340 --> 01:01:35.500] Before I talk about QLORA, [01:01:35.500 --> 01:01:37.220] I'll give you a little bit of background. [01:01:37.220 --> 01:01:40.260] So this work is about quantization, about compression. [01:01:40.260 --> 01:01:42.140] So we do, for example, quantization [01:01:42.140 --> 01:01:44.460] if we have a 32-bit float number, [01:01:44.460 --> 01:01:47.620] and we want to quantize it to a 4-bit integer. [01:01:47.620 --> 01:01:50.780] In this diagram, I have a histogram, [01:01:50.780 --> 01:01:54.300] which is equivalent to an int4 quantization [01:01:54.300 --> 01:01:56.060] with 16 different bins. [01:01:56.060 --> 01:01:58.380] And in red, I have the normal distribution. [01:01:58.380 --> 01:02:00.660] And if we want to quantize all the values [01:02:00.660 --> 01:02:04.380] in the normal distribution to a 4-bit integer, [01:02:04.380 --> 01:02:06.820] we need to reduce all these values to 16 different values. [01:02:06.820 --> 01:02:07.940] How do we do that? [01:02:07.940 --> 01:02:09.660] We find the empirical minimum [01:02:09.660 --> 01:02:11.980] and maximum range of the distribution, [01:02:11.980 --> 01:02:13.720] and then we slice this distribution [01:02:13.720 --> 01:02:17.300] in 16 different slices with equal width. [01:02:17.300 --> 01:02:19.860] Each of these slices is a quantization bin, [01:02:19.860 --> 01:02:21.740] and all the values contained [01:02:21.740 --> 01:02:24.700] of the normal distribution in this bin [01:02:24.700 --> 01:02:26.900] are quantized to the middle value of the bin. [01:02:26.900 --> 01:02:29.180] And with that, we can reduce all the values [01:02:29.180 --> 01:02:32.340] in the normal distribution just to 16 different values. [01:02:32.340 --> 01:02:34.280] And this is an int4 quantization. [01:02:34.280 --> 01:02:38.020] Now, if we do other quantizations with other data types, [01:02:38.020 --> 01:02:39.540] we have different ranges. [01:02:39.540 --> 01:02:41.340] And so what I do in my work is [01:02:41.340 --> 01:02:43.020] I generalize these data types [01:02:43.020 --> 01:02:45.620] by normalizing the range the data types take [01:02:45.620 --> 01:02:48.140] to the range minus one and one. 
[01:02:48.140 --> 01:02:50.500] This approach is also called a codebook, [01:02:50.500 --> 01:02:52.460] where you map an index [01:02:52.460 --> 01:02:54.660] to a particular values in the data type. [01:02:54.660 --> 01:02:57.300] And so if we have this codebook, [01:02:57.300 --> 01:03:00.820] there's a two-step recipe how we can quantize any tensor. [01:03:00.820 --> 01:03:03.060] And so we take the tensor X, [01:03:03.060 --> 01:03:05.180] then we normalize it into the range, [01:03:05.180 --> 01:03:09.260] oh, sorry, and we normalize it into the range minus one, one [01:03:09.260 --> 01:03:11.780] by dividing by the absolute maximum value. [01:03:11.780 --> 01:03:14.980] And then we go through each element in the tensor [01:03:14.980 --> 01:03:18.060] and find the closest value in the data type. [01:03:18.060 --> 01:03:20.100] We do that by doing a binary search [01:03:20.100 --> 01:03:22.860] on the sorted values in the data type. [01:03:22.860 --> 01:03:26.220] And with that, we can then quantize the entire tensor. [01:03:26.220 --> 01:03:29.500] Just to make this a little clearer, here's an example. [01:03:29.500 --> 01:03:33.220] This is a very unusual two-bit data type. [01:03:33.220 --> 01:03:37.900] It has the values minus one, 0.3, 0.5, and 1.0. [01:03:37.900 --> 01:03:40.740] The input tensor is 10, minus three, five, four. [01:03:40.740 --> 01:03:43.220] And now let's go through the steps of the recipe. [01:03:43.220 --> 01:03:46.460] So first we find the absolute maximum value, which is 10. [01:03:46.460 --> 01:03:50.980] We divide by it, we get one, minus 0.3, 0.5, 0.4. [01:03:50.980 --> 01:03:54.020] And then we find the closest value of these values [01:03:54.020 --> 01:03:56.300] for each element associated in the data type. [01:03:56.300 --> 01:04:00.020] We get one, 0.3, 0.5, 0.5. [01:04:00.020 --> 01:04:03.980] Then we find the associated index of these values. [01:04:03.980 --> 01:04:06.020] And this is now a two-bit representation. [01:04:06.020 --> 01:04:08.580] Now we can store it, and it's compressed. [01:04:08.580 --> 01:04:10.700] If we want to dequantize these values, [01:04:10.700 --> 01:04:12.820] we just do all the steps in reverse. [01:04:12.820 --> 01:04:15.820] So we look up the associated values in the data type, [01:04:15.820 --> 01:04:17.700] and then we denormalize by multiplying [01:04:17.700 --> 01:04:19.700] by the absolute maximum value of 10. [01:04:19.700 --> 01:04:22.020] It gives us 10, 3, 5, 5. [01:04:22.020 --> 01:04:25.220] And so if we compare input and output tensors, what we see [01:04:25.220 --> 01:04:27.820] is that we have two big errors. [01:04:27.820 --> 01:04:32.340] The minus 3 turned into a 3, and the 4 turned into a 5. [01:04:32.340 --> 01:04:33.860] These are quantization errors. [01:04:33.860 --> 01:04:36.260] And so the main challenge in quantization research [01:04:36.260 --> 01:04:40.860] is we want to compress a neural network with a low-precision [01:04:40.860 --> 01:04:44.060] data type, but we want to keep all the quantization errors [01:04:44.060 --> 01:04:45.020] minimal. [01:04:45.020 --> 01:04:47.060] If the quantization errors are large, [01:04:47.060 --> 01:04:49.300] we degrade the neural network performance, [01:04:49.300 --> 01:04:50.620] and we want to avoid that. [01:04:50.620 --> 01:04:52.780] And that's the main challenge. [01:04:52.780 --> 01:04:54.900] Let's talk a little bit about fine-tuning. [01:04:54.900 --> 01:04:56.620] Why is it so expensive? 
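Here is a small sketch of that two-step codebook recipe in code, following the worked 2-bit example above (codebook values -1.0, 0.3, 0.5, 1.0, input tensor 10, -3, 5, 4). Note that the 0.4 entry sits exactly halfway between 0.3 and 0.5 in real arithmetic, so which side it lands on depends on rounding and tie-breaking. This is only an illustration of the idea, not the authors' actual implementation:

```python
import numpy as np

def quantize(x, codebook):
    """Absmax codebook quantization: return indices plus the scale constant."""
    x = np.asarray(x, dtype=np.float64)
    codebook = np.asarray(codebook, dtype=np.float64)
    absmax = np.max(np.abs(x))              # step 1: normalization constant
    normalized = x / absmax                 # values now lie in [-1, 1]
    # step 2: nearest codebook value for each element
    idx = np.argmin(np.abs(normalized[:, None] - codebook[None, :]), axis=1)
    return idx, absmax

def dequantize(idx, absmax, codebook):
    """Reverse the steps: look up codebook values, then rescale by absmax."""
    return np.asarray(codebook, dtype=np.float64)[idx] * absmax

codebook = [-1.0, 0.3, 0.5, 1.0]            # the unusual 2-bit data type from the talk
x = [10.0, -3.0, 5.0, 4.0]
idx, absmax = quantize(x, codebook)
print(idx)                                  # 2-bit indices into the codebook
print(dequantize(idx, absmax, codebook))    # approximately [10. 3. 5. 5.]
# The -3 -> 3 and 4 -> 5 mismatches are the quantization errors discussed above.
```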
[01:04:56.620 --> 01:04:59.620] So the best way to look at it is to look at the cost [01:04:59.620 --> 01:05:01.380] per parameter in fine-tuning. [01:05:01.380 --> 01:05:04.260] And so the per-parameter cost for full fine-tuning [01:05:04.260 --> 01:05:08.060] is 16-bit for each weight, 16-bit for each weight [01:05:08.060 --> 01:05:12.860] gradient, and 64-bit if we use Adam for each parameter. [01:05:12.860 --> 01:05:15.580] That gives us 12 bytes per parameter. [01:05:15.580 --> 01:05:17.060] And if you have a 70-billion model, [01:05:17.060 --> 01:05:20.380] that's 840 gigabytes of GPU memory. [01:05:20.380 --> 01:05:22.540] 36 consumer GPUs. [01:05:22.540 --> 01:05:24.600] That's a lot of memory. [01:05:24.600 --> 01:05:28.340] If you use low-rank adapters (LoRA), we get much more efficient. [01:05:28.340 --> 01:05:31.900] And so what we do there is we take a pre-trained model, [01:05:31.900 --> 01:05:33.060] we freeze it. [01:05:33.060 --> 01:05:36.300] Now we put some tiny layers on top of it, some adapters. [01:05:36.300 --> 01:05:39.420] And so if we fine-tune it, we do stochastic gradient descent [01:05:39.420 --> 01:05:42.220] through the frozen layers into the adapters, [01:05:42.220 --> 01:05:45.300] and we just update the adapters, not the main model. [01:05:45.300 --> 01:05:47.500] And so what that does is the weights still [01:05:47.500 --> 01:05:50.180] need 16 bits per value. [01:05:50.180 --> 01:05:52.660] But now all the other values that are updated, [01:05:52.660 --> 01:05:55.260] they're only a fraction of a bit on average. [01:05:55.260 --> 01:05:58.980] And so in total, we have 17.6 bits per parameter. [01:05:58.980 --> 01:06:01.860] That adds up to 150 gigabytes of memory, [01:06:01.860 --> 01:06:04.020] which is 8 consumer GPUs. [01:06:04.020 --> 01:06:06.100] Now with our development of QLoRA, [01:06:06.100 --> 01:06:08.340] we go a step further. [01:06:08.340 --> 01:06:11.820] So now we take the pre-trained model, quantize it to 4-bit, [01:06:11.820 --> 01:06:13.580] and then put adapters on top. [01:06:13.580 --> 01:06:15.460] That reduces the average footprint [01:06:15.460 --> 01:06:19.580] to 5.2 bits per parameter, which is 46 gigabytes. [01:06:19.580 --> 01:06:22.180] And that fits into two consumer GPUs. [01:06:22.180 --> 01:06:24.620] Now the main challenge is we want [01:06:24.620 --> 01:06:27.260] to preserve the performance while doing this 4-bit [01:06:27.260 --> 01:06:27.880] compression. [01:06:27.880 --> 01:06:29.580] And that is the main challenge. [01:06:29.580 --> 01:06:31.700] So we have three innovations that [01:06:31.700 --> 01:06:35.100] improve the memory footprint, but then also the precision, [01:06:35.100 --> 01:06:37.780] to reduce the quantization error. [01:06:37.780 --> 01:06:40.540] There's one part, paged optimizers, [01:06:40.540 --> 01:06:41.540] that I will not talk about. [01:06:41.540 --> 01:06:43.340] You can read about it in the paper. [01:06:43.340 --> 01:06:46.220] It's used to prevent memory spikes during fine-tuning [01:06:46.220 --> 01:06:49.980] if you hit a long document during your fine-tuning run. [01:06:49.980 --> 01:06:51.680] The main contribution that we have [01:06:51.680 --> 01:06:54.020] is the 4-bit normal float data type. [01:06:54.020 --> 01:06:56.300] This is a data type that's information-theoretically [01:06:56.300 --> 01:06:57.140] optimal. [01:06:57.140 --> 01:06:58.980] And so you can think about it like this.
[01:06:58.980 --> 01:07:01.500] So in the beginning, I showed you an int4 quantization [01:07:01.500 --> 01:07:04.340] where the quantization bins have equal width. [01:07:04.340 --> 01:07:07.980] In a normal float data type, the bins have equal area. [01:07:07.980 --> 01:07:11.100] That means each slice has equal probability mass [01:07:11.100 --> 01:07:13.660] in the normal distribution. [01:07:13.660 --> 01:07:16.180] And that means the same amount of values [01:07:16.180 --> 01:07:18.140] are quantized into each bin. [01:07:18.140 --> 01:07:20.180] With that, each bin has an equal amount of values, [01:07:20.180 --> 01:07:23.540] and it's information-theoretically optimal. [01:07:23.540 --> 01:07:26.460] Our second contribution is a little bit silly. [01:07:26.460 --> 01:07:27.700] It's double quantization. [01:07:27.700 --> 01:07:30.380] We do a quantization of the quantization. [01:07:30.380 --> 01:07:31.900] And so what does that look like? [01:07:31.900 --> 01:07:34.340] So in the normal quantization, we take the weight, [01:07:34.340 --> 01:07:38.020] quantize it, and now we get two pieces, the quantized weights, [01:07:38.020 --> 01:07:40.200] and then the absolute maximum constants. [01:07:40.200 --> 01:07:41.860] We have multiple constants because we [01:07:41.860 --> 01:07:44.580] slice the weight into blocks, and each block [01:07:44.580 --> 01:07:46.180] has its own constant. [01:07:46.180 --> 01:07:48.020] And so we get a matrix of constants. [01:07:48.020 --> 01:07:51.060] On average, these are 0.5 bits, and that's [01:07:51.060 --> 01:07:53.700] multiple gigabytes of GPU memory. [01:07:53.700 --> 01:07:56.700] And now we quantize those constants again. [01:07:56.700 --> 01:07:59.280] We save about 0.4 bits on average. [01:07:59.280 --> 01:08:02.340] And that is important if we want to fit large models [01:08:02.340 --> 01:08:08.140] into consumer GPUs, because otherwise they don't quite fit. [01:08:08.140 --> 01:08:09.900] And so these are the contributions. [01:08:09.900 --> 01:08:12.180] Now let's look at the results. [01:08:12.180 --> 01:08:13.900] So the main thing that we want is [01:08:13.900 --> 01:08:15.700] to replicate 16-bit performance. [01:08:15.700 --> 01:08:17.260] That was our main goal. [01:08:17.260 --> 01:08:19.620] And so what I have here is different LLaMA [01:08:19.620 --> 01:08:21.020] models of different sizes. [01:08:21.020 --> 01:08:25.860] And we fine-tune on the FLAN v2 instruction data set. [01:08:25.860 --> 01:08:29.220] We evaluate on MMLU accuracy. [01:08:29.220 --> 01:08:33.820] We have in pink the 16-bit baseline in BrainFloat16. [01:08:33.820 --> 01:08:36.740] And what we see now is that the regular float data [01:08:36.740 --> 01:08:40.260] type, 4-bit float in blue, [01:08:40.260 --> 01:08:42.740] doesn't quite replicate 16-bit performance. [01:08:42.740 --> 01:08:45.500] However, if you use our normal float data type, [01:08:45.500 --> 01:08:48.340] we get up to 16-bit performance. [01:08:48.340 --> 01:08:52.900] And so with that, we have now replicated 16-bit performance. [01:08:52.900 --> 01:08:54.700] In our paper, we have many more experiments [01:08:54.700 --> 01:08:56.900] that show the same finding. [01:08:56.900 --> 01:08:59.220] But with that now, we are at the stage [01:08:59.220 --> 01:09:02.120] where we can very efficiently fine-tune [01:09:02.120 --> 01:09:05.300] very large language models with very little resources.
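To make the "equal area" idea concrete, a 16-value codebook can be sketched from quantiles of a standard normal distribution; this captures the spirit of the 4-bit NormalFloat, though the paper's exact construction handles the bin edges and the zero point more carefully:

```python
import numpy as np
from scipy.stats import norm

k = 16  # 4 bits -> 16 codebook values

# Take the midpoints of 16 equal-probability-mass slices of N(0, 1),
# then rescale so the codebook spans [-1, 1].
quantiles = norm.ppf(np.linspace(0.5 / k, 1 - 0.5 / k, k))
codebook = quantiles / np.abs(quantiles).max()

print(np.round(codebook, 3))
# The values cluster near 0, where normally distributed weights concentrate,
# and spread out in the tails -- unlike the evenly spaced int4 grid.
```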
[01:09:05.300 --> 01:09:07.940] And so now we go a step further and ask, [01:09:07.940 --> 01:09:10.260] can we build a high-quality chatbot [01:09:10.260 --> 01:09:13.660] now that we very quickly can explore all possibilities [01:09:13.660 --> 01:09:14.980] with cheap fine-tuning? [01:09:14.980 --> 01:09:19.180] And so through our experiments, we ran over 1,000 experiments. [01:09:19.180 --> 01:09:21.300] We find a very good data set and build [01:09:21.300 --> 01:09:24.820] a chatbot called Guanaco, which is a 4-bit chatbot. [01:09:24.820 --> 01:09:28.100] We created it by just fine-tuning on a single consumer [01:09:28.100 --> 01:09:30.460] GPU for 24 hours. [01:09:30.460 --> 01:09:33.420] And now we want to compare how good our chatbot is compared [01:09:33.420 --> 01:09:37.140] to other chatbots that are trained or fine-tuned in 16-bit. [01:09:37.140 --> 01:09:39.540] And so we have a tournament-style setup [01:09:39.540 --> 01:09:45.500] where the setup is we have 80 different prompts [01:09:45.500 --> 01:09:47.420] from the Vicuna data set. [01:09:47.420 --> 01:09:50.900] And we give this prompt to two random chatbots. [01:09:50.900 --> 01:09:54.020] And then they compete to generate the best response. [01:09:54.020 --> 01:09:56.100] Each chatbot generates a response. [01:09:56.100 --> 01:10:00.860] And then the responses are judged by the humans or GPT-4. [01:10:00.860 --> 01:10:04.860] And either humans or GPT-4 say which response is better. [01:10:04.860 --> 01:10:07.380] This is a game. [01:10:07.380 --> 01:10:10.460] And so we play multiple games of many random allocations [01:10:10.460 --> 01:10:11.200] of chatbots. [01:10:11.200 --> 01:10:13.500] And with that, we can determine which chatbot is better [01:10:13.500 --> 01:10:14.900] than another chatbot. [01:10:14.900 --> 01:10:17.700] If we do this setup, then we find [01:10:17.700 --> 01:10:21.700] that humans think our chatbot on these Vicuna prompts [01:10:21.700 --> 01:10:24.100] is a little bit better than ChatGPT. [01:10:24.100 --> 01:10:27.380] If we ask GPT-4, it says it's about the same quality [01:10:27.380 --> 01:10:29.100] as ChatGPT. [01:10:29.100 --> 01:10:32.060] This doesn't mean that our bot is as good as ChatGPT. [01:10:32.060 --> 01:10:33.980] But for these particular prompts, [01:10:33.980 --> 01:10:37.420] it is about the same quality. [01:10:37.420 --> 01:10:38.620] On the right is also a demo. [01:10:38.620 --> 01:10:41.620] You can scan it and try our chatbot. [01:10:41.620 --> 01:10:43.860] And that's everything that I have. [01:10:43.860 --> 01:10:47.020] So just to conclude, QLoRA makes fine-tuning [01:10:47.020 --> 01:10:48.740] 18 times cheaper. [01:10:48.740 --> 01:10:50.780] With the 4-bit normal float, we can [01:10:50.780 --> 01:10:53.500] replicate 16-bit fine-tuning performance. [01:10:53.500 --> 01:10:56.740] And we have also shown that you can create very high-quality [01:10:56.740 --> 01:10:58.380] chatbots with QLoRA. [01:10:58.380 --> 01:11:01.020] So with all of that, it's very simple [01:11:01.020 --> 01:11:04.900] to now create high-quality fine-tuned models. [01:11:04.900 --> 01:11:07.180] And it's so cheap that everybody has access [01:11:07.180 --> 01:11:10.020] to fine-tuning these large models. [01:11:10.020 --> 01:11:12.580] QLoRA is available in the bitsandbytes library. [01:11:12.580 --> 01:11:15.320] And it's also integrated in the Hugging Face Transformers stack. [01:11:15.320 --> 01:11:18.260] And so there, you can very easily use it.
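Since QLoRA ships in bitsandbytes and is integrated into the Hugging Face stack, usage looks roughly like the sketch below (assuming recent transformers, peft, and bitsandbytes releases; the model name and LoRA hyperparameters are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "huggyllama/llama-7b"  # illustrative; any causal LM on the Hub works

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",            # the 4-bit NormalFloat data type
    bnb_4bit_use_double_quant=True,       # quantize the quantization constants too
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)

model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)   # only the LoRA adapters are trainable
model.print_trainable_parameters()
```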
[01:11:18.260 --> 01:11:20.100] I'm also on the academic job market. [01:11:20.100 --> 01:11:22.820] So please get in touch if you're interested. [01:11:22.820 --> 01:11:24.880] Later this week, I will also give [01:11:24.880 --> 01:11:27.620] a talk on the making of QLoRA at the workshop. [01:11:27.620 --> 01:11:30.980] So stay tuned on Twitter for more information about that. [01:11:30.980 --> 01:11:31.940] And that's what I have. [01:11:31.940 --> 01:11:33.340] And I'm happy to take questions. [01:11:33.340 --> 01:11:34.420] Thank you so much. [01:11:34.420 --> 01:11:38.340] [APPLAUSE] [01:11:38.340 --> 01:11:41.100] So we're going to make a bit of a hard pivot [01:11:41.100 --> 01:11:44.300] now from the world of optimization, fine-tuning, [01:11:44.300 --> 01:11:47.620] and training methods into the world of multimodality, which [01:11:47.620 --> 01:11:49.140] is another big theme of this year [01:11:49.140 --> 01:11:51.220] and probably every year to come. [01:11:51.220 --> 01:11:53.700] Every previous paper we've covered on the pod up [01:11:53.700 --> 01:11:55.820] to this point, I've heard of online. [01:11:55.820 --> 01:11:57.220] And it's relatively well-known. [01:11:57.220 --> 01:12:00.400] You didn't actually need to meet the people to hear about them. [01:12:00.400 --> 01:12:02.740] But one of the joys of coming to a conference like NeurIPS [01:12:02.740 --> 01:12:05.220] is finding things that you may not [01:12:05.220 --> 01:12:08.500] have seen just because of your filter bubble [01:12:08.500 --> 01:12:10.940] or just because there's just way too many things out there. [01:12:10.940 --> 01:12:13.020] And you didn't have the time to look into them. [01:12:13.020 --> 01:12:15.300] And this was definitely true for me for DataComp, [01:12:15.300 --> 01:12:18.940] which I had never heard of, but which is also a very legitimate effort. [01:12:18.940 --> 01:12:21.140] And I actually had a chat with them after their talk. [01:12:21.140 --> 01:12:23.420] But first, let's introduce what DataComp is. [01:12:23.420 --> 01:12:24.520] My name is Samir. [01:12:24.520 --> 01:12:27.140] And this is Gabriel Ilharco. [01:12:27.140 --> 01:12:28.300] And this is Alex Fang. [01:12:28.300 --> 01:12:31.180] And today, we're going to be presenting our work, DataComp: [01:12:31.180 --> 01:12:34.580] In Search of the Next Generation of Multimodal Datasets. [01:12:34.580 --> 01:12:37.380] And this paper was really made possible [01:12:37.380 --> 01:12:39.180] by a whole team of people. [01:12:39.180 --> 01:12:41.060] And so we're very lucky and fortunate to be [01:12:41.060 --> 01:12:46.300] able to share it on behalf of the whole team. [01:12:46.300 --> 01:12:46.780] OK. [01:12:46.780 --> 01:12:49.300] So we want to start with a little bit of a history [01:12:49.300 --> 01:12:53.060] of computer vision models. [01:12:53.060 --> 01:12:55.900] So in this kind of traditional paradigm of image [01:12:55.900 --> 01:12:57.660] classification, what we would do is [01:12:57.660 --> 01:13:00.140] we would create a specialized data set. [01:13:00.140 --> 01:13:02.660] We'll call that a traditional supervised data set [01:13:02.660 --> 01:13:04.860] with certain class labels. [01:13:04.860 --> 01:13:08.580] For example, 10 different labels for the MNIST data set. [01:13:08.580 --> 01:13:10.500] And then we would train these fixed models [01:13:10.500 --> 01:13:12.300] on these kinds of data sets.
[01:13:12.300 --> 01:13:13.820] And this was really cool because it [01:13:13.820 --> 01:13:16.700] led to all kinds of architectural improvements. [01:13:16.700 --> 01:13:19.620] You can think ResNets, skip connections, [01:13:19.620 --> 01:13:22.700] applications of attention. [01:13:22.700 --> 01:13:25.580] But when you needed to add an additional task, [01:13:25.580 --> 01:13:29.060] say ImageNet 1K, you had to create a new data [01:13:29.060 --> 01:13:31.660] set with a new set of labels. [01:13:31.660 --> 01:13:36.100] And this was a laborious process. [01:13:36.100 --> 01:13:40.580] But then right around 2021, something really cool happened. [01:13:40.580 --> 01:13:42.740] The paradigm a little bit switched [01:13:42.740 --> 01:13:45.380] to these image text data sets that [01:13:45.380 --> 01:13:48.940] allowed trading these open vocabulary models. [01:13:48.940 --> 01:13:54.540] And suddenly, we could do things like train a unified model that [01:13:54.540 --> 01:13:58.220] could then downstream do arbitrary image classification [01:13:58.220 --> 01:13:59.180] tasks. [01:13:59.180 --> 01:14:02.540] And this is really a sort of data set transition [01:14:02.540 --> 01:14:06.620] is kind of the takeaway here. [01:14:06.620 --> 01:14:10.580] So in spite of this kind of transition between data sets, [01:14:10.580 --> 01:14:13.300] the standard machine learning pipeline [01:14:13.300 --> 01:14:15.460] actually stayed relatively consistent. [01:14:15.460 --> 01:14:17.140] So what we're still going to do is [01:14:17.140 --> 01:14:21.520] create a monolithic artifact, a data set, keep that fixed, [01:14:21.520 --> 01:14:25.100] and then iterate on model training on that data set. [01:14:25.100 --> 01:14:28.060] And this is still a really cool recipe. [01:14:28.060 --> 01:14:32.020] And it's led to progress in downstream evaluations. [01:14:32.020 --> 01:14:34.240] But what we really ask in Data Comp [01:14:34.240 --> 01:14:37.020] and the center of our paper is, how much performance [01:14:37.020 --> 01:14:39.220] are we actually leaving on the table [01:14:39.220 --> 01:14:42.020] by adopting the standard ML pipeline? [01:14:42.020 --> 01:14:45.980] Can we actually improve models by iterating on data sets [01:14:45.980 --> 01:14:49.540] instead of on model architectures? [01:14:49.540 --> 01:14:52.140] And so fundamentally, Data Comp is a benchmark [01:14:52.140 --> 01:14:57.460] for data set development to help the community understand [01:14:57.460 --> 01:15:01.980] how data set decisions improve models. [01:15:01.980 --> 01:15:04.820] So specifically, we're going to look at this CLIP trading [01:15:04.820 --> 01:15:09.020] regime for these more modern image text data sets, which [01:15:09.020 --> 01:15:11.140] are popular nowadays. [01:15:11.140 --> 01:15:13.340] And so we want to give just a brief overview of CLIP [01:15:13.340 --> 01:15:16.140] so that we're all kind of on the same page. [01:15:16.140 --> 01:15:20.300] So we roughly have a text encoder and an image encoder. [01:15:20.300 --> 01:15:23.860] And we're going to train these encoders from scratch, [01:15:23.860 --> 01:15:30.220] contrastively, in order to align image and text representations. 
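The contrastive training just described corresponds to a symmetric InfoNCE loss over a batch of paired image and text embeddings, roughly in the spirit of the CLIP paper's pseudocode (a sketch, with encoder internals omitted and the temperature value chosen for illustration):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of paired image/text embeddings."""
    img = F.normalize(image_features, dim=-1)
    txt = F.normalize(text_features, dim=-1)
    logits = img @ txt.t() / temperature                      # [batch, batch] similarity matrix
    targets = torch.arange(img.size(0), device=img.device)    # i-th image pairs with i-th caption
    loss_img = F.cross_entropy(logits, targets)               # image -> text direction
    loss_txt = F.cross_entropy(logits.t(), targets)           # text -> image direction
    return (loss_img + loss_txt) / 2
```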
[01:15:30.220 --> 01:15:33.800] And then downstream, if we have a new classification task, [01:15:33.800 --> 01:15:36.700] we're going to do things like write sentences, [01:15:36.700 --> 01:15:39.940] a photo of a plane, a photo of a car, et cetera, [01:15:39.940 --> 01:15:43.900] and then query an image feature against all of these text [01:15:43.900 --> 01:15:45.740] features to retrieve our class label. [01:15:45.740 --> 01:15:51.300] So kind of recentering things back to Data Comp now, [01:15:51.300 --> 01:15:53.460] the picture I think that we should all have in mind [01:15:53.460 --> 01:15:56.580] is we're actually going to fix this CLIP bit, which [01:15:56.580 --> 01:15:59.860] is this middle trading diagram. [01:15:59.860 --> 01:16:02.780] And we're going to iterate on the data selection process [01:16:02.780 --> 01:16:06.580] to create new data sets to train our CLIP models. [01:16:06.580 --> 01:16:08.860] And now I'm going to hand it over to Alex. [01:16:12.140 --> 01:16:14.060] So the Data Comp workflow consists [01:16:14.060 --> 01:16:17.420] of five steps, choosing a scale, selecting data, [01:16:17.420 --> 01:16:21.300] training a model, evaluating, and submitting the results. [01:16:21.300 --> 01:16:24.460] And the first step is choosing the scale, which roughly [01:16:24.460 --> 01:16:27.140] reflects the amount of compute used. [01:16:27.140 --> 01:16:31.020] So Data Comp has four scales. [01:16:31.020 --> 01:16:34.060] At the small scale, we train a VIT B32 [01:16:34.060 --> 01:16:36.180] for 12.8 million samples, which is [01:16:36.180 --> 01:16:39.380] equivalent to fine-tuning a model on ImageNet 1K. [01:16:39.380 --> 01:16:41.660] At the medium scale, we train a VIT B32 [01:16:41.660 --> 01:16:44.140] at 128 million samples seen, which [01:16:44.140 --> 01:16:46.020] is equivalent to training a model from scratch [01:16:46.020 --> 01:16:47.820] on ImageNet 1K. [01:16:47.820 --> 01:16:52.500] At large, we train a VIT B16 for 1.28 billion samples seen, [01:16:52.500 --> 01:16:55.340] which is equivalent to training an ImageNet 21K model [01:16:55.340 --> 01:16:56.540] from scratch. [01:16:56.540 --> 01:17:00.660] And at extra large, we train for 12.8 billion samples seen [01:17:00.660 --> 01:17:02.940] on a VIT L14, which is equivalent to training [01:17:02.940 --> 01:17:05.700] an OpenAI CLIP model. [01:17:05.700 --> 01:17:07.980] One key design decision is that there is no constraint [01:17:07.980 --> 01:17:09.340] on data set size. [01:17:09.340 --> 01:17:10.860] We build our scale configurations [01:17:10.860 --> 01:17:13.540] around samples seen, because practically speaking, [01:17:13.540 --> 01:17:16.740] the key constraints are pool size and compute. [01:17:16.740 --> 01:17:20.300] This means each data point in a data set of 6.4 million samples [01:17:20.300 --> 01:17:22.140] at the small scale is seen twice. [01:17:22.140 --> 01:17:27.860] At the chosen scale, participants [01:17:27.860 --> 01:17:29.780] can then use their data selection method [01:17:29.780 --> 01:17:32.860] on either a fixed provided pool of raw data [01:17:32.860 --> 01:17:36.180] or are free to bring in additional data. [01:17:36.180 --> 01:17:38.780] So in the first option, which is the filtering track, [01:17:38.780 --> 01:17:41.220] participants filter from a provided raw pool [01:17:41.220 --> 01:17:44.940] equivalent in size to the sample seen at the chosen scale. [01:17:44.940 --> 01:17:48.380] Our pool, which we call Common Pool, comes from Common Crawl. 
[01:17:48.380 --> 01:17:50.100] And then we do minimal preprocessing, [01:17:50.100 --> 01:17:52.780] such as near-duplicate checking against evaluation [01:17:52.780 --> 01:17:55.780] and not-safe-for-work filtering. [01:17:55.780 --> 01:17:57.300] Additionally, we provide metadata [01:17:57.300 --> 01:17:59.820] to help with potential filtering approaches. [01:17:59.820 --> 01:18:02.340] This metadata includes original width and height, [01:18:02.340 --> 01:18:06.500] caption, a checksum, CLIP features, CLIP scores, [01:18:06.500 --> 01:18:08.940] and face bounding boxes for automatic blurring [01:18:08.940 --> 01:18:12.260] to help with privacy concerns. [01:18:12.260 --> 01:18:14.940] The second option is the Bring Your Own Data track. [01:18:14.940 --> 01:18:17.820] This allows participants to use additional data sources, [01:18:17.820 --> 01:18:20.140] as well as both edit and generate images and captions [01:18:20.140 --> 01:18:21.700] from Common Pool. [01:18:21.700 --> 01:18:23.500] We hope this track supports participants [01:18:23.500 --> 01:18:26.460] whose creative approaches do not fit neatly into the filtering [01:18:26.460 --> 01:18:28.900] track, while also maintaining fair comparison [01:18:28.900 --> 01:18:32.460] within the filtering track. [01:18:32.460 --> 01:18:35.300] Next, participants use a fixed training procedure [01:18:35.300 --> 01:18:38.740] to train a model on their newly filtered data. [01:18:38.740 --> 01:18:41.140] For training, we adopt fixed training recipes, [01:18:41.140 --> 01:18:43.780] including hyperparameters for CLIP training. [01:18:43.780 --> 01:18:46.980] And this was based on prior experience. [01:18:46.980 --> 01:18:48.860] Notably, Data Comp participants are not [01:18:48.860 --> 01:18:50.500] allowed to modify these parameters, [01:18:50.500 --> 01:18:53.940] therefore focusing investigation on data set selection. [01:18:53.940 --> 01:18:56.300] In the paper, we show that better data sets are largely [01:18:56.300 --> 01:18:58.580] consistent across variations in training recipes. [01:18:58.580 --> 01:19:04.100] Once models are trained, they are evaluated [01:19:04.100 --> 01:19:07.300] using our provided script. [01:19:07.300 --> 01:19:10.300] Our evaluation suite contains 38 downstream tasks, [01:19:10.300 --> 01:19:14.220] which include image net and variance, a subset of VTAB, [01:19:14.220 --> 01:19:17.700] a subset of wild distribution shifts, fairness benchmarks, [01:19:17.700 --> 01:19:19.900] and retrieval benchmarks. [01:19:19.900 --> 01:19:22.500] And the evaluations are done in a zero shot manner [01:19:22.500 --> 01:19:24.060] to remove the need for fine tuning [01:19:24.060 --> 01:19:27.420] on each individual downstream task. [01:19:27.420 --> 01:19:31.460] And the last step of the process is to submit your results. [01:19:31.460 --> 01:19:33.620] We provide an online leaderboard that participants [01:19:33.620 --> 01:19:36.020] can submit to, which we hope promotes participation [01:19:36.020 --> 01:19:37.460] and collaboration. [01:19:37.460 --> 01:19:39.780] We believe that many of these individual data filtering [01:19:39.780 --> 01:19:41.220] approaches should stack. [01:19:41.220 --> 01:19:44.580] And when combined, will lead to better results. [01:19:44.580 --> 01:19:45.960] Next, I'll hand it over to Gabriel [01:19:45.960 --> 01:19:47.820] to talk about baselines and some new results. [01:19:47.820 --> 01:19:53.380] All right, so let's talk about experiments now. 
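As a brief aside before the experiments, the prompt-based zero-shot evaluation described above looks roughly like this, assuming a CLIP-style model that exposes encode_image and encode_text (the prompt template and helper names are illustrative, not DataComp's evaluation script):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def zero_shot_classify(model, tokenizer, image_tensor, class_names):
    """Zero-shot classification with a CLIP-style model exposing encode_image / encode_text."""
    prompts = [f"a photo of a {name}" for name in class_names]   # e.g. "a photo of a plane"
    text_feats = F.normalize(model.encode_text(tokenizer(prompts)), dim=-1)
    image_feat = F.normalize(model.encode_image(image_tensor), dim=-1)
    sims = image_feat @ text_feats.t()        # cosine similarity against every class prompt
    return class_names[sims.argmax().item()]  # the best-matching prompt gives the class label
```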
[01:19:53.380 --> 01:19:55.060] We study many baselines in our paper, [01:19:55.060 --> 01:19:57.100] but I'll focus on the two most interesting ones [01:19:57.100 --> 01:19:59.700] in the interest of time. [01:19:59.700 --> 01:20:03.820] The first one is what we call CLIP score filtering. [01:20:03.820 --> 01:20:05.980] The idea behind CLIP score filtering is simple. [01:20:05.980 --> 01:20:08.700] We use a pre-trained CLIP model to compute cosine similarity [01:20:08.700 --> 01:20:11.340] scores for all image-text pairs in our data set. [01:20:11.340 --> 01:20:13.760] In this plot, you can see a distribution of these scores [01:20:13.760 --> 01:20:15.620] in our data set. [01:20:15.620 --> 01:20:17.580] We then choose a threshold for the similarity, [01:20:17.580 --> 01:20:19.900] for example, corresponding to the top 30% [01:20:19.900 --> 01:20:23.300] of scores in our unfiltered pool. [01:20:23.300 --> 01:20:24.900] We then remove all samples that have [01:20:24.900 --> 01:20:26.500] similarity smaller than this threshold, [01:20:26.500 --> 01:20:28.660] keeping only the samples with high score [01:20:28.660 --> 01:20:30.540] as a proxy for discarding all samples [01:20:30.540 --> 01:20:34.060] that we think have low quality. [01:20:34.060 --> 01:20:35.520] Another filtering baseline is what [01:20:35.520 --> 01:20:38.940] we call image-based filtering. [01:20:38.940 --> 01:20:42.140] For image-based filtering, we again use a trained CLIP model, [01:20:42.140 --> 01:20:45.100] but this time only to extract image features. [01:20:45.100 --> 01:20:47.820] We then cluster these image features [01:20:47.820 --> 01:20:50.980] and find clusters that match images in ImageNet. [01:20:50.980 --> 01:20:54.220] We keep all clusters that are assigned at least one ImageNet image. [01:20:54.220 --> 01:20:56.820] We then discard all the other clusters. [01:20:56.820 --> 01:20:59.340] Note that this filtering is purely based on image features, [01:20:59.340 --> 01:21:01.780] and we do not use any labels or captions [01:21:01.780 --> 01:21:05.260] for this filtering strategy. [01:21:05.260 --> 01:21:07.020] Our best performing baseline is built [01:21:07.020 --> 01:21:10.540] by intersecting the two baselines I just described, [01:21:10.540 --> 01:21:13.620] CLIP score filtering and image-based filtering. [01:21:13.620 --> 01:21:16.820] When we apply this technique to our larger pool, [01:21:16.820 --> 01:21:20.300] we find a data set with 1.4 billion samples [01:21:20.300 --> 01:21:24.180] that we call DataComp-1B. [01:21:24.180 --> 01:21:27.940] So let's see how well this works in practice. [01:21:27.940 --> 01:21:30.380] We conducted over 300 pre-training experiments [01:21:30.380 --> 01:21:33.240] with many different strategies for filtering our pool. [01:21:33.240 --> 01:21:37.260] Our best data set is DataComp-1B, a 1.4-billion-sample subset [01:21:37.260 --> 01:21:39.700] of our pool that leads to much higher accuracy [01:21:39.700 --> 01:21:42.580] than existing data sets, including OpenAI's WIT [01:21:42.580 --> 01:21:44.060] and LAION-2B. [01:21:44.060 --> 01:21:47.820] This is the first public data set that outperforms OpenAI's. [01:21:47.820 --> 01:21:50.300] Also note that all these models are compute-matched, [01:21:50.300 --> 01:21:53.020] so these gains come at no extra cost at training time.
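Once the cosine similarities are precomputed (as they are in the CommonPool metadata), the CLIP score filtering baseline described here reduces to a simple thresholding step; a sketch, with a made-up score array standing in for the real pool:

```python
import numpy as np

def clip_score_filter(clip_scores, keep_fraction=0.30):
    """Return the indices of the top `keep_fraction` image-text pairs by CLIP cosine similarity."""
    clip_scores = np.asarray(clip_scores)
    threshold = np.quantile(clip_scores, 1.0 - keep_fraction)   # e.g. the 70th percentile
    return np.nonzero(clip_scores >= threshold)[0]

scores = np.random.rand(1_000_000)   # stand-in for precomputed image-text cosine similarities
kept = clip_score_filter(scores, keep_fraction=0.30)
print(f"kept {len(kept)} of {len(scores)} samples")
```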
[01:21:53.020 --> 01:21:59.260] One key finding from our work is that smaller, more aggressively [01:21:59.260 --> 01:22:02.260] filtered data sets can perform better than larger data sets [01:22:02.260 --> 01:22:04.260] coming from the same pool. [01:22:04.260 --> 01:22:06.580] As you can see on the plot, when we selected [01:22:06.580 --> 01:22:09.380] samples that have the highest cosine similarity according [01:22:09.380 --> 01:22:11.940] to a train clip model, there is a sweet spot [01:22:11.940 --> 01:22:14.020] for the size of the data set that we keep, [01:22:14.020 --> 01:22:16.840] around 30% of the original pool. [01:22:16.840 --> 01:22:18.980] This means that you're better off using a smaller [01:22:18.980 --> 01:22:22.340] subset of the pool instead of using more noisier data. [01:22:22.340 --> 01:22:26.340] Interestingly, this doesn't happen [01:22:26.340 --> 01:22:28.060] when you sample randomly from the pool, [01:22:28.060 --> 01:22:30.180] as you can see from the dotted line. [01:22:30.180 --> 01:22:32.340] So you can get away with smaller data sets, [01:22:32.340 --> 01:22:34.780] but you do need to be a bit more careful in how [01:22:34.780 --> 01:22:35.820] you are selecting samples. [01:22:35.820 --> 01:22:40.060] Another key finding from our experiments [01:22:40.060 --> 01:22:42.580] is that the ranking of different filtering strategies [01:22:42.580 --> 01:22:44.960] is relatively stable across scales, [01:22:44.960 --> 01:22:46.980] as you can see in these scatter plots. [01:22:46.980 --> 01:22:49.500] These plots show how performance on the small scale [01:22:49.500 --> 01:22:52.180] correlates to performance on the medium scale. [01:22:52.180 --> 01:22:54.020] And while it's not a perfect correlation, [01:22:54.020 --> 01:22:56.300] these plots show that there is hope for doing research [01:22:56.300 --> 01:22:58.480] at smaller scales, since there is a good chance [01:22:58.480 --> 01:23:01.540] that findings will generalize to larger scales. [01:23:01.540 --> 01:23:03.100] And in fact, this is exactly how we [01:23:03.100 --> 01:23:05.300] proceeded during our experiments, [01:23:05.300 --> 01:23:08.140] by first testing things out at smaller scales [01:23:08.140 --> 01:23:10.820] and only scaling up the most promising results. [01:23:10.820 --> 01:23:15.180] This saves us a lot of compute during our experiments. [01:23:15.180 --> 01:23:16.940] There's much more in the paper, [01:23:16.940 --> 01:23:18.300] as you can see in these slides. [01:23:18.300 --> 01:23:20.300] And if you're interested, definitely check it out. [01:23:20.300 --> 01:23:22.740] We are very happy to answer any questions [01:23:22.740 --> 01:23:25.900] and talk more about any of these topics in our poster. [01:23:25.900 --> 01:23:30.980] Since we released the paper, [01:23:30.980 --> 01:23:33.980] there's been a lot of activity in Data Comp. [01:23:33.980 --> 01:23:36.060] The fun thing is that our best performing baseline, [01:23:36.060 --> 01:23:38.140] which we thought was pretty decent, [01:23:38.140 --> 01:23:40.660] were blown out of the water by the community since. [01:23:40.660 --> 01:23:43.760] And it's just really nice to see that happening in real time. [01:23:43.760 --> 01:23:47.060] One example is Data Filtering Networks, or DFN for short, [01:23:47.060 --> 01:23:50.220] where the main idea is similar to clip score filtering, [01:23:50.220 --> 01:23:51.900] but with a deeper dive into what makes [01:23:51.900 --> 01:23:54.120] a good model for data filtering. 
[01:23:54.120 --> 01:23:58.500] And careful data creation has led to [01:23:58.500 --> 01:24:00.460] what now are the best CLIP models, [01:24:00.460 --> 01:24:03.700] even outside DataComp, with an impressive 84.4% [01:24:03.700 --> 01:24:06.660] zero-shot accuracy on ImageNet, using a ViT-H/14. [01:24:06.660 --> 01:24:10.340] The central takeaway I'd like to leave with you today [01:24:10.340 --> 01:24:12.380] is that careful experimentation with data sets [01:24:12.380 --> 01:24:15.060] can really pay off, and can lead to very large improvements [01:24:15.060 --> 01:24:17.140] in the performance of downstream models. [01:24:17.140 --> 01:24:19.440] So instead of blindly scaling models up, [01:24:19.440 --> 01:24:21.460] I think we as a community should start paying [01:24:21.460 --> 01:24:24.220] more attention to how we design data sets. [01:24:24.220 --> 01:24:26.940] DataComp is designed to facilitate [01:24:26.940 --> 01:24:28.580] research in that direction. [01:24:28.580 --> 01:24:31.120] It's amazing to see what people are already building with it, [01:24:31.120 --> 01:24:33.380] and I'm super excited to see what comes next. [01:24:33.380 --> 01:24:36.300] Finally, I'd like to reiterate that our benchmark [01:24:36.300 --> 01:24:39.100] is designed to encourage everyone to participate, [01:24:39.100 --> 01:24:41.860] even if you only have a couple of GPUs under your desk. [01:24:42.460 --> 01:24:44.660] (audience laughs) [01:24:44.660 --> 01:24:46.860] So if any of this sounds interesting at all to you, [01:24:46.860 --> 01:24:48.420] feel free to check out our resources, [01:24:48.420 --> 01:24:50.500] including our website, code base, and paper. [01:24:50.500 --> 01:24:52.520] Everything we do is fully open source, [01:24:52.520 --> 01:24:55.820] and we hope these resources are useful for the community. [01:24:55.820 --> 01:24:57.380] Thank you very much. [01:24:57.380 --> 01:25:00.380] (audience applauds) [01:25:00.380 --> 01:25:03.180] - So I quite enjoyed that presentation, [01:25:03.180 --> 01:25:05.660] and obviously this being an image-heavy [01:25:05.660 --> 01:25:07.700] and multi-modal type of paper, [01:25:07.700 --> 01:25:09.300] you should probably check out the images [01:25:09.300 --> 01:25:12.420] and the competition at datacomp.ai. [01:25:12.420 --> 01:25:14.380] But I did manage to catch up with them [01:25:14.380 --> 01:25:16.460] at their poster session and ask them more questions. [01:25:16.460 --> 01:25:18.820] It turns out there's some intellectual lineage [01:25:18.820 --> 01:25:20.900] from LAION with LAION-5B, [01:25:20.900 --> 01:25:23.260] and I do think that this has a strong chance [01:25:23.260 --> 01:25:26.180] to become the new ImageNet, so let's give them a listen. [01:25:26.180 --> 01:25:28.900] Oh, fun fact, they were also wearing DataComp t-shirts. [01:25:28.900 --> 01:25:30.900] Most people, when they present their poster sessions, [01:25:30.900 --> 01:25:33.860] they're in kind of just somewhat semi-formal attire. [01:25:33.860 --> 01:25:35.900] These guys, they make custom t-shirts for their posters, [01:25:35.900 --> 01:25:37.580] so you know how serious they are. [01:25:37.580 --> 01:25:38.740] - My name is Samir. [01:25:38.740 --> 01:25:42.420] I'm a fourth-year PhD student at Columbia. [01:25:42.420 --> 01:25:45.740] I started working on DataComp, [01:25:45.740 --> 01:25:49.520] like, I guess around November of last year.
[01:25:49.520 --> 01:25:52.700] I had collaborated with a lot of the folks [01:25:52.700 --> 01:25:55.980] that are already on the paper on previous projects, [01:25:55.980 --> 01:25:59.100] like Mitchell Wortsman, Ludwig, Vaishaal, [01:25:59.100 --> 01:26:02.100] and they kind of just kind of roped me in. [01:26:02.100 --> 01:26:03.420] They were looking for hands [01:26:03.420 --> 01:26:05.620] to help out with different tasks, [01:26:05.620 --> 01:26:07.700] and then through the course of time, [01:26:07.700 --> 01:26:08.940] my involvement just kind of grew [01:26:08.940 --> 01:26:10.400] 'cause I got really excited about it. [01:26:10.400 --> 01:26:13.220] - Yeah, how did this become such a big effort? [01:26:13.220 --> 01:26:14.540] Like, you guys are wearing t-shirts. [01:26:14.540 --> 01:26:15.740] - Yeah. - This is not normal. [01:26:15.740 --> 01:26:20.500] - Yeah, yeah, yeah, so we really took this project [01:26:20.500 --> 01:26:23.700] very seriously 'cause we wanted the benchmark [01:26:23.700 --> 01:26:25.740] to be really good and thorough, [01:26:25.740 --> 01:26:29.900] and because of that, we were working at kind of a scale [01:26:29.900 --> 01:26:32.420] that was kind of unprecedented for academics. [01:26:32.420 --> 01:26:36.780] We generated the pool of 12.8 billion image-text pairs. [01:26:36.780 --> 01:26:39.020] We wanted evaluations to be very thorough, [01:26:39.020 --> 01:26:40.920] many, many downstream tasks, [01:26:40.920 --> 01:26:45.200] and that just took a lot of people to commit to the project. [01:26:45.200 --> 01:26:47.500] - And how do people find out about something like this? [01:26:47.500 --> 01:26:50.260] Like, is there, you're not from the same university. [01:26:50.260 --> 01:26:51.260] Is there a community somewhere [01:26:51.260 --> 01:26:53.300] that you all just gather and coordinate? [01:26:53.300 --> 01:26:57.540] - Yeah, yes, so Ludwig, who's kind of the last author [01:26:57.540 --> 01:27:02.260] on this paper, is kind of networked all around. [01:27:02.260 --> 01:27:05.260] He's very friendly and very open to collaboration, [01:27:05.260 --> 01:27:07.180] and I think, because of him, [01:27:07.180 --> 01:27:09.960] many people from many different universities, [01:27:09.960 --> 01:27:12.140] corporations, were able to join. [01:27:12.140 --> 01:27:15.260] - Yeah, and this is separate from the LAION group? [01:27:15.260 --> 01:27:17.260] - Yeah, so Ludwig is affiliated with LAION. [01:27:17.260 --> 01:27:18.260] - 'Cause I've seen his name around. [01:27:18.260 --> 01:27:21.780] - Yeah, yeah, but most of the people [01:27:21.780 --> 01:27:24.860] are not necessarily part of LAION, [01:27:24.860 --> 01:27:28.220] but we all kind of know each other and collaborate. [01:27:28.220 --> 01:27:33.060] - I mean, wouldn't it be better to make this LAION-1.4B? [01:27:33.060 --> 01:27:36.320] - Yeah, so we maybe could have done that. [01:27:36.320 --> 01:27:39.020] - I think, sorry, LAION-12.8B, right? [01:27:39.020 --> 01:27:41.420] Like, LAION has 5B? [01:27:41.420 --> 01:27:44.180] - Yeah, LAION has a 5B and a 2B subset [01:27:44.180 --> 01:27:45.660] that people train on a lot. [01:27:45.660 --> 01:27:47.140] Yeah, we could have done that. [01:27:47.140 --> 01:27:48.980] I think we were thinking about things [01:27:48.980 --> 01:27:51.360] more from the standpoint of a benchmark, [01:27:51.360 --> 01:27:54.340] and that was really our focus.
[01:27:54.340 --> 01:27:57.900] While this 1b dataset that came out of the benchmark [01:27:57.900 --> 01:28:01.720] is an artifact, we kind of wanted to place emphasis [01:28:01.720 --> 01:28:02.980] on the benchmark itself, yeah. [01:28:02.980 --> 01:28:04.940] - So, just to comment on that, [01:28:04.940 --> 01:28:09.820] the idea of, initially it was a dataset, [01:28:09.820 --> 01:28:11.740] but then we thought also about benchmarking, [01:28:11.740 --> 01:28:14.580] but then we thought about the real thing is the community. [01:28:14.580 --> 01:28:17.060] So, we thought that the way that you can actually [01:28:17.060 --> 01:28:21.100] build a community is by opening up the tooling. [01:28:21.100 --> 01:28:23.100] So, a lot of the, in dataset curation, [01:28:23.100 --> 01:28:27.180] the problems is not about, usually you work super hard, [01:28:27.180 --> 01:28:28.540] and then at the end of the day, you make a dataset, [01:28:28.540 --> 01:28:30.160] you release a dataset, and you're done. [01:28:30.160 --> 01:28:33.540] But, the tools that you developed to actually [01:28:33.540 --> 01:28:35.420] clean up the dataset, filter the dataset, [01:28:35.420 --> 01:28:37.980] benchmark the dataset, these are often more valuable [01:28:37.980 --> 01:28:39.940] for other people who want to do the same job. [01:28:39.940 --> 01:28:43.520] So, the central idea was to make a community [01:28:43.520 --> 01:28:46.800] and to open source the tools in addition to the dataset, [01:28:46.800 --> 01:28:49.780] and then allow other people to try different tooling methods [01:28:49.780 --> 01:28:52.300] so going this data-centric AI direction. [01:28:52.300 --> 01:28:54.980] So, that was kind of one of the central ideas [01:28:54.980 --> 01:28:55.820] around Datacom. [01:28:55.820 --> 01:28:57.620] So, Datacom is really about building community [01:28:57.620 --> 01:28:59.640] around dataset curation. [01:28:59.640 --> 01:29:02.620] - And this is the first time I've seen [01:29:02.620 --> 01:29:05.820] like clip score filtering applied like this. [01:29:05.820 --> 01:29:08.060] Is there like, and you also mentioned [01:29:08.060 --> 01:29:09.700] at the end of your oral presentation [01:29:09.700 --> 01:29:11.040] that there were other methods. [01:29:11.040 --> 01:29:12.580] Like, what filtering methods are you seeing [01:29:12.580 --> 01:29:13.840] that are working really well? [01:29:13.840 --> 01:29:15.380] - It's a whole community of people [01:29:15.380 --> 01:29:17.580] who are trying a gazillion different tricks. [01:29:17.580 --> 01:29:19.860] And that's the whole point, right? [01:29:19.860 --> 01:29:22.700] One remarkable thing to point out is that [01:29:22.700 --> 01:29:25.180] if you picked a benchmark, you will see performance [01:29:25.180 --> 01:29:26.660] changing across different benchmarks. [01:29:26.660 --> 01:29:31.660] But we'll see surprising correlation of ImageNet zero shot [01:29:31.660 --> 01:29:33.280] to gazillion other benchmarks. [01:29:33.280 --> 01:29:36.280] So, we have 38 benchmarks, and we see that [01:29:36.280 --> 01:29:38.200] if you do well, basically zero shot on ImageNet, [01:29:38.200 --> 01:29:40.080] you're very, very correlated in predicting [01:29:40.080 --> 01:29:42.640] in how good your model is across the board [01:29:42.640 --> 01:29:45.360] for retrieval, for all kinds of very useful things. [01:29:45.360 --> 01:29:47.160] And the community is developing a gazillion type [01:29:47.160 --> 01:29:48.760] of different methods of data curation. 
[01:29:48.760 --> 01:29:49.960] That's why we have a leaderboard [01:29:49.960 --> 01:29:50.880] and we're building a community. [01:29:50.880 --> 01:29:52.720] This is not like a paper and we're done. [01:29:52.720 --> 01:29:55.580] You could write like 20 projects [01:29:55.580 --> 01:29:57.300] of different data set curation. [01:29:57.300 --> 01:30:00.920] It's more like a platform for data set curation evaluations. [01:30:00.920 --> 01:30:04.180] Do you remember other methods that are doing well? [01:30:04.180 --> 01:30:06.940] - Yeah, yeah, so that's a great overview. [01:30:06.940 --> 01:30:09.500] And yeah, I think specifically people have been looking [01:30:09.500 --> 01:30:12.260] into designing filtering networks. [01:30:12.260 --> 01:30:15.260] So, rather than using Clip to be the filtering network, [01:30:15.260 --> 01:30:19.080] like what are some other data sets that we might train on [01:30:19.080 --> 01:30:21.020] in order to create these filtering networks? [01:30:21.020 --> 01:30:23.460] What are the differences between a good Clip model [01:30:23.460 --> 01:30:25.340] and a good filtering model? [01:30:25.340 --> 01:30:27.100] So, these are all kind of open questions [01:30:27.100 --> 01:30:29.860] that, as Alex was saying, the community will answer [01:30:29.860 --> 01:30:32.760] by trying a bunch of tricks and methods. [01:30:32.760 --> 01:30:34.760] - Yeah, so you can train like new Clip, [01:30:34.760 --> 01:30:38.540] but you can also train new Stable Diffusion from this. [01:30:38.540 --> 01:30:42.140] Which I'm sure Stability is interested in this. [01:30:42.140 --> 01:30:45.860] Unless you're working on your own sort of Diffusion model. [01:30:45.860 --> 01:30:48.580] - Yeah, so the problem is they compute [01:30:48.580 --> 01:30:51.540] to train multiple Stable Diffusions is needed. [01:30:51.540 --> 01:30:53.700] But yeah, we're definitely interested in that direction [01:30:53.700 --> 01:30:56.500] and we're definitely thinking about that. [01:30:56.500 --> 01:30:59.440] But you could basically include quality [01:30:59.440 --> 01:31:02.260] of a Stable Diffusion as a benchmark [01:31:02.260 --> 01:31:05.700] and evaluate how you would select a subset [01:31:05.700 --> 01:31:09.140] of data to improve on that benchmark. [01:31:09.140 --> 01:31:11.220] - You might wanna talk to Luther. [01:31:11.220 --> 01:31:14.020] I've been talking to Stella Biederman from Luther. [01:31:14.020 --> 01:31:16.620] She's around here, she'll come by. [01:31:16.620 --> 01:31:17.920] Cool, any other future directions [01:31:17.920 --> 01:31:19.740] that you're very excited about? [01:31:19.740 --> 01:31:22.660] - Yeah, we're actually really excited [01:31:22.660 --> 01:31:25.480] about just the concept of Data Comp high level. [01:31:25.480 --> 01:31:29.580] So, right now we're pretty excited about NLP [01:31:29.580 --> 01:31:31.860] and what a Data Comp Lite effort [01:31:31.860 --> 01:31:33.780] would look like in that space. [01:31:33.780 --> 01:31:35.500] - You could extend this approach [01:31:35.500 --> 01:31:38.820] to audio, potentially video. [01:31:38.820 --> 01:31:40.740] Although, video's tricky for me [01:31:40.740 --> 01:31:42.920] just 'cause it's so data heavy. [01:31:42.920 --> 01:31:46.600] There's a lot of orders of magnitude [01:31:46.600 --> 01:31:48.320] of different dimensions it could go to. [01:31:48.320 --> 01:31:50.800] So, I don't know what that might look like. [01:31:50.800 --> 01:31:52.560] - Let me tell you about that. 
[01:31:52.560 --> 01:31:57.560] So, one idea is you can make a DataComp for MRI images [01:31:57.560 --> 01:32:00.840] or a DataComp for this, a DataComp for that. [01:32:00.840 --> 01:32:02.540] What's the idea, what does that mean? [01:32:02.540 --> 01:32:04.840] It means you fix the model. [01:32:04.840 --> 01:32:06.200] So, classical machine learning, [01:32:06.200 --> 01:32:07.720] as was mentioned in the talk, [01:32:07.720 --> 01:32:08.640] classical machine learning says, [01:32:08.640 --> 01:32:11.200] "Here's a data set, ImageNet, build a million models [01:32:11.200 --> 01:32:13.060] "and tell me what's the best one." [01:32:13.060 --> 01:32:16.680] Now, the DataComp idea flips this on its head. [01:32:16.680 --> 01:32:20.240] It says, "Here is a big pool of data. [01:32:20.240 --> 01:32:21.760] "The model is fixed. [01:32:21.760 --> 01:32:25.760] "You only select a subset of the pool." [01:32:25.760 --> 01:32:27.120] So, the thing you're selecting [01:32:27.120 --> 01:32:29.480] is which images to keep in the pool. [01:32:29.480 --> 01:32:30.760] Then, the model is fixed. [01:32:30.760 --> 01:32:33.240] But, you're training other machine learning models [01:32:33.240 --> 01:32:36.060] to select what to keep. [01:32:36.060 --> 01:32:37.640] And, that's very powerful. [01:32:37.640 --> 01:32:41.160] So, in classical AI, [01:32:41.160 --> 01:32:44.240] if you're doing the data cleaning, the data filtering, [01:32:44.240 --> 01:32:45.480] that's like the shittiest job. [01:32:45.480 --> 01:32:47.520] (laughing) [01:32:47.520 --> 01:32:50.080] But, we're trying to make that a first-class citizen [01:32:50.080 --> 01:32:52.200] and try and tell you that it's worth doing research on, [01:32:52.200 --> 01:32:54.200] because it's not that you will manually sit down [01:32:54.200 --> 01:32:57.320] and select images from five billion or 13 billion. [01:32:57.320 --> 01:32:59.560] You will be building models that do that. [01:32:59.560 --> 01:33:02.220] So, you can do a DataComp for X, [01:33:02.220 --> 01:33:04.220] and we're seeing that from the community. [01:33:06.620 --> 01:33:08.780] - I'm curious how you became involved with DataComp. [01:33:08.780 --> 01:33:13.780] - Yeah, so we have this NSF institute. [01:33:13.780 --> 01:33:14.860] It's called IFML, [01:33:14.860 --> 01:33:17.220] the Institute for the Foundations of Machine Learning. [01:33:17.220 --> 01:33:19.860] And, Ludwig is part of our institute. [01:33:19.860 --> 01:33:23.660] So, we were having lunch and we were discussing [01:33:23.660 --> 01:33:25.660] about LAION, right? [01:33:25.660 --> 01:33:27.860] And, how to make a better LAION. [01:33:27.860 --> 01:33:30.620] And, we said, okay, instead of just making a better LAION, [01:33:30.620 --> 01:33:32.580] which is what we also started with, [01:33:32.580 --> 01:33:36.420] let's make it a community where we open the tools. [01:33:36.420 --> 01:33:38.220] So, everybody can make a better LAION. [01:33:38.220 --> 01:33:39.900] So, that was the central idea, yeah. [01:33:39.900 --> 01:33:41.780] - What happened to the original LAION? [01:33:41.780 --> 01:33:44.860] - LAION is still a great data set that's still public, [01:33:44.860 --> 01:33:47.340] but this is basically building the next generation, yeah. [01:33:47.340 --> 01:33:48.740] - Yeah, yeah, very cool. [01:33:48.740 --> 01:33:49.580] I wish you good luck. [01:33:49.580 --> 01:33:51.240] I think this is really foundational work.
[01:33:51.240 --> 01:33:53.580] It's basically the new ImageNet, right? [01:33:53.580 --> 01:33:57.780] That's 10 years after the original AlexNet moment. [01:33:57.780 --> 01:34:00.100] By the way, that second speaker who was not introduced [01:34:00.100 --> 01:34:02.900] was Alex Dimakis, who is a professor at UT Austin, [01:34:02.900 --> 01:34:04.980] who just jumped in and chatted. [01:34:04.980 --> 01:34:07.660] And, I do find it a very charming element of NeurIPS [01:34:07.660 --> 01:34:11.300] that it's effectively a coming out party/hiring party [01:34:11.300 --> 01:34:13.180] where all the grad students publish their papers. [01:34:13.180 --> 01:34:15.860] They all have sponsors and more senior researchers [01:34:15.860 --> 01:34:19.320] and professors as the secondary or tertiary authors, [01:34:19.320 --> 01:34:20.460] but their name goes first [01:34:20.460 --> 01:34:22.820] because then they get all the credit and the citations. [01:34:22.820 --> 01:34:24.340] And, the people who are more senior [01:34:24.340 --> 01:34:26.060] just kind of stand there and support them. [01:34:26.060 --> 01:34:28.420] And, Alex definitely jumped in and supported them. [01:34:28.420 --> 01:34:30.820] Just like I saw a bunch of other senior authors [01:34:30.820 --> 01:34:31.940] supporting their grad students [01:34:31.940 --> 01:34:34.060] and directing questions to their grad students [01:34:34.060 --> 01:34:35.820] because their reputations are already secure. [01:34:35.820 --> 01:34:36.660] They have jobs. [01:34:36.660 --> 01:34:39.820] They're just here to help their interns and grad students. [01:34:39.820 --> 01:34:41.060] There's a very interesting tension [01:34:41.060 --> 01:34:45.100] between effectively datasets papers and models papers. [01:34:45.100 --> 01:34:47.380] The datasets people think that their work [01:34:47.380 --> 01:34:48.980] is more long lasting, [01:34:48.980 --> 01:34:51.620] and the models people think that datasets work is dumb. [01:34:51.620 --> 01:34:53.700] And, I think you just need both. [01:34:53.700 --> 01:34:57.660] So, that's my awkward transition from DataComp into LLaVA, [01:34:57.660 --> 01:35:00.740] which is probably the single most interesting [01:35:00.740 --> 01:35:04.020] visual language model this year. [01:35:04.020 --> 01:35:07.300] As much as people are in love with GPT-4 Vision, [01:35:07.300 --> 01:35:08.300] it's not open source, [01:35:08.300 --> 01:35:10.980] and we don't really honestly know very much about it. [01:35:10.980 --> 01:35:13.780] But, LLaVA is open and trainable [01:35:13.780 --> 01:35:15.820] with a whole bunch of open source models. [01:35:15.820 --> 01:35:17.500] And, together with DataComp, [01:35:17.500 --> 01:35:19.020] I think LLaVA and DataComp together [01:35:19.020 --> 01:35:20.860] will provide some kind of template [01:35:20.860 --> 01:35:23.260] for the next generation of multimodal models to form. [01:35:23.260 --> 01:35:25.220] So, let's check out LLaVA. [01:35:25.220 --> 01:35:28.740] - I'm Haotian, a final-year PhD student at UW-Madison, [01:35:28.740 --> 01:35:30.860] and I'm on the job market. [01:35:30.860 --> 01:35:34.660] Today, I'm presenting Visual Instruction Tuning, [01:35:34.660 --> 01:35:37.100] a joint work with Chunyuan, Qingyang, [01:35:37.100 --> 01:35:38.580] and my advisor, Yong Jae.
[01:35:38.580 --> 01:35:42.260] So, as a background, we, as humans, [01:35:42.260 --> 01:35:44.580] we can see and reason about the visual world, [01:35:44.580 --> 01:35:47.180] express and interact with natural language. [01:35:47.180 --> 01:35:48.820] Doctors read the CT scans [01:35:48.820 --> 01:35:51.700] and explain their findings to their patients. [01:35:51.700 --> 01:35:54.740] Teachers teach students with conversations. [01:35:54.740 --> 01:35:57.540] And, we share our life and findings on social media [01:35:57.540 --> 01:35:59.140] and interact with others. [01:35:59.140 --> 01:36:00.420] It will be great if we can have [01:36:00.420 --> 01:36:02.180] a visual intelligent assistant [01:36:02.180 --> 01:36:04.060] that can reason about the visual world [01:36:04.060 --> 01:36:06.060] and reflect with language. [01:36:06.060 --> 01:36:09.420] The closest work along this direction [01:36:09.420 --> 01:36:11.220] are image-to-text generation models, [01:36:11.220 --> 01:36:13.460] where the model takes in the image as the input [01:36:13.460 --> 01:36:16.460] and outputs text, reflecting its understanding. [01:36:16.460 --> 01:36:19.100] Such models, like GIT, BLIP-2, and Flamingo, [01:36:19.100 --> 01:36:21.420] have basic visual reasoning capability, [01:36:21.420 --> 01:36:23.260] while they generally lack the ability [01:36:23.260 --> 01:36:25.020] to follow very complex instructions [01:36:25.020 --> 01:36:27.620] or engage in very long conversations. [01:36:27.620 --> 01:36:30.700] Back in March, OpenAI demonstrated GPT-4 Vision [01:36:30.700 --> 01:36:33.020] with strong visual reasoning capability. [01:36:33.020 --> 01:36:35.020] For example, given such an image, [01:36:35.020 --> 01:36:38.620] and the user requests, "What's unusual about this image?" [01:36:38.620 --> 01:36:43.260] GPT-4 Vision is able to reason beyond just visual facts. [01:36:43.260 --> 01:36:45.540] It's able to figure out that the unusual thing [01:36:45.540 --> 01:36:47.980] is actually that the man is ironing clothes [01:36:47.980 --> 01:36:50.660] while standing on the back of a taxi. [01:36:50.660 --> 01:36:54.780] It's great, but it wasn't accessible until very recently, [01:36:54.780 --> 01:36:56.940] and there's no disclosure on how it works. [01:36:56.940 --> 01:36:59.780] So, if we are able to create an open-source model [01:36:59.780 --> 01:37:03.060] with a similar level of visual reasoning capability, [01:37:03.060 --> 01:37:05.180] it will be great, as it allows us [01:37:05.180 --> 01:37:07.580] to have a deeper understanding of how the models behave, [01:37:07.580 --> 01:37:09.220] and we can have a joint effort [01:37:09.220 --> 01:37:11.460] from the whole community to make it better. [01:37:11.460 --> 01:37:13.820] So, as a starting point, [01:37:13.820 --> 01:37:15.740] how can we create such multi-modal models [01:37:15.740 --> 01:37:18.300] that can actually follow humans' intent? [01:37:18.300 --> 01:37:20.620] In NLP, researchers find that instruction tuning [01:37:20.620 --> 01:37:22.820] allows the model to learn to follow the instructions [01:37:22.820 --> 01:37:24.380] by fine-tuning the model on a small set [01:37:24.380 --> 01:37:26.260] of instruction and answer pairs, [01:37:26.260 --> 01:37:28.060] like explaining human behavior [01:37:28.060 --> 01:37:30.060] or recommending movies.
[01:37:30.060 --> 01:37:31.940] And creating such instruction [01:37:31.940 --> 01:37:34.140] just by letting human writing it is very costly, [01:37:34.140 --> 01:37:37.540] and Self-Instruct proposed to use teacher models [01:37:37.540 --> 01:37:39.660] like ChatGPT to create such instructions [01:37:39.660 --> 01:37:43.020] by expanding a small set of seed instruction output pairs [01:37:43.020 --> 01:37:45.220] to million-scale using in-context learning, [01:37:45.220 --> 01:37:47.740] and it's affordable, and has been used [01:37:47.740 --> 01:37:49.780] to create open-source language models [01:37:49.780 --> 01:37:52.900] like Alpaca based on the base Lama model. [01:37:52.900 --> 01:37:54.140] So, now the question is, [01:37:54.140 --> 01:37:56.900] how can we create visual instruction following models? [01:37:56.900 --> 01:37:59.380] And let's start with this basic architecture [01:37:59.380 --> 01:38:01.360] where we have an image first, [01:38:01.360 --> 01:38:02.660] we have a visual encoder, [01:38:02.660 --> 01:38:04.820] which can encode it into the visual features, [01:38:04.820 --> 01:38:07.420] a cross-modal connector that can bridge it [01:38:07.420 --> 01:38:09.100] to the language decoder. [01:38:09.100 --> 01:38:11.480] The language decoder also takes in the user instructions [01:38:11.480 --> 01:38:12.940] and perform the reasoning [01:38:12.940 --> 01:38:16.700] and output its understanding using the text. [01:38:16.700 --> 01:38:19.300] So the key is, how do we train this model [01:38:19.300 --> 01:38:20.940] for following multi-modal instructions, [01:38:20.940 --> 01:38:22.860] and how do we obtain such data? [01:38:22.860 --> 01:38:25.460] The straightforward way would be to use a Self-Instruct, [01:38:25.460 --> 01:38:29.940] and let's find a multi-modal teacher and let it expand. [01:38:29.940 --> 01:38:33.420] However, if we take a look at those existing teachers' models [01:38:33.420 --> 01:38:35.380] that were used, they are all text-only, [01:38:35.380 --> 01:38:38.420] and there were no powerful multi-modal teachers. [01:38:38.420 --> 01:38:42.060] And in our paper, we propose to leverage a text-only GPT, [01:38:42.060 --> 01:38:46.180] and we provide image context in the textual format to GPT [01:38:46.180 --> 01:38:48.040] so that it can understand. [01:38:48.040 --> 01:38:50.460] For example, here, we have an image, [01:38:50.460 --> 01:38:53.300] and we can use the COCO annotations captions [01:38:53.300 --> 01:38:55.100] so we can have an image-level context [01:38:55.100 --> 01:38:57.460] which describes what's happening in the image. [01:38:57.460 --> 01:38:59.000] We can also have the bounding box [01:38:59.000 --> 01:39:01.400] and object category annotations from the COCO [01:39:01.400 --> 01:39:03.960] so that we are able to get region-level context, [01:39:03.960 --> 01:39:05.500] which provides even more details [01:39:05.500 --> 01:39:08.260] that may not be captured in the captions. [01:39:08.260 --> 01:39:12.840] So let's take a closer look on our text-only data engine. [01:39:12.840 --> 01:39:15.700] We will have two parts of the inputs. [01:39:15.700 --> 01:39:17.540] First are the in-context examples, [01:39:17.540 --> 01:39:20.900] which are the exemplars that we guide the ChatGPT [01:39:20.900 --> 01:39:23.580] on how they should generate the visual instructions. [01:39:23.580 --> 01:39:25.460] So we'll have example image. 
[01:39:25.460 --> 01:39:28.300] We convert them into the image context in the textual format [01:39:28.300 --> 01:39:29.900] that we just described. [01:39:29.900 --> 01:39:31.900] We write the instruction and answers [01:39:31.900 --> 01:39:34.780] about the visual content in the image. [01:39:34.780 --> 01:39:38.100] These are the examples for ChatGPT to learn from. [01:39:38.100 --> 01:39:40.020] And then we do the actual inference, [01:39:40.020 --> 01:39:42.540] and for any image in the COCO training data set, [01:39:42.540 --> 01:39:44.680] we're able to convert them into textual format [01:39:44.680 --> 01:39:46.280] using the COCO annotation. [01:39:46.280 --> 01:39:48.260] And ChatGPT will just learn to generate [01:39:48.260 --> 01:39:52.740] those instructions and answers about those image contexts. [01:39:52.740 --> 01:39:55.340] We gather the instructions, answers, and also the image [01:39:55.340 --> 01:39:57.540] to create our visual instruction-following data, [01:39:57.540 --> 01:40:02.740] which is a triplet of image, instruction, and answer. [01:40:02.740 --> 01:40:05.020] To better facilitate learning, we [01:40:05.020 --> 01:40:07.660] create three types of responses. [01:40:07.660 --> 01:40:10.000] First is a conversation to facilitate [01:40:10.000 --> 01:40:12.300] multi-turn engagement, detailed description [01:40:12.300 --> 01:40:14.620] to train the model to focus on visual details, [01:40:14.620 --> 01:40:17.440] and complex reasoning to allow the model to reason [01:40:17.440 --> 01:40:19.180] beyond the visual facts. [01:40:19.180 --> 01:40:22.300] For example, question, what challenges do people face? [01:40:22.300 --> 01:40:24.620] The model not only needs to figure out [01:40:24.620 --> 01:40:26.660] that there is luggage, there are bags, [01:40:26.660 --> 01:40:27.700] there are SUVs. [01:40:27.700 --> 01:40:29.660] It also needs to figure out the challenge: [01:40:29.660 --> 01:40:31.740] they may not be able to fit all the luggage [01:40:31.740 --> 01:40:33.860] in the back of the SUV. [01:40:33.860 --> 01:40:37.780] So we create LLaVA-Instruct-158K and train LLaVA. [01:40:37.780 --> 01:40:40.740] It's a model composed of three simple components. [01:40:40.740 --> 01:40:44.260] We use CLIP as the vision encoder, an instruction-tuned language model, [01:40:44.260 --> 01:40:46.580] Vicuna, as the language decoder, [01:40:46.580 --> 01:40:49.700] and we use a linear layer for the projection. [01:40:49.700 --> 01:40:51.940] And we find it works quite well because the CLIP visual [01:40:51.940 --> 01:40:53.860] features already carry great semantics, [01:40:53.860 --> 01:40:56.580] and a single linear layer is sufficient to project it [01:40:56.580 --> 01:41:00.580] into a space that the language decoder can understand well. [01:41:00.580 --> 01:41:03.140] For model training, we use a two-stage training pipeline, [01:41:03.140 --> 01:41:05.780] where in the first stage, we pre-train the projector only [01:41:05.780 --> 01:41:07.540] for feature alignment so that it's [01:41:07.540 --> 01:41:09.920] projected into a proper space. [01:41:09.920 --> 01:41:12.620] And in stage two, we perform end-to-end visual instruction [01:41:12.620 --> 01:41:15.780] tuning on the generated visual instruction-following data set. [01:41:15.780 --> 01:41:16.660] We train the projector. [01:41:16.660 --> 01:41:17.940] We train the language model.
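As a rough sketch of the architecture just described, a vision encoder, a single linear projector, and a language decoder, the code below wires up stand-in modules in PyTorch; the stub classes, dimensions, and names are invented for illustration and are not the actual LLaVA implementation.

```python
import torch
import torch.nn as nn

# Minimal sketch of a LLaVA-style model: a vision encoder, a linear projector,
# and a language decoder. The encoder/decoder are stand-in stubs, not CLIP/Vicuna.

class StubVisionEncoder(nn.Module):          # stands in for CLIP patch features
    def __init__(self, num_patches=256, vision_dim=1024):
        super().__init__()
        self.num_patches, self.vision_dim = num_patches, vision_dim
    def forward(self, images):               # images: (batch, 3, H, W)
        b = images.shape[0]
        return torch.randn(b, self.num_patches, self.vision_dim)  # fake patch features

class StubLanguageDecoder(nn.Module):        # stands in for Vicuna
    def __init__(self, hidden_dim=4096, vocab_size=32000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)
    def forward(self, input_embeds):         # (batch, seq, hidden_dim)
        return self.lm_head(input_embeds)    # next-token logits

class LlavaStyleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = StubVisionEncoder()
        self.decoder = StubLanguageDecoder()
        # Cross-modal connector: one linear layer from vision space to LM space.
        self.projector = nn.Linear(self.vision.vision_dim, 4096)

    def forward(self, images, instruction_ids):
        image_feats = self.vision(images)                    # (b, patches, vision_dim)
        image_tokens = self.projector(image_feats)           # (b, patches, hidden_dim)
        text_tokens = self.decoder.embed(instruction_ids)    # (b, seq, hidden_dim)
        # Prepend projected image tokens to the instruction embeddings.
        inputs = torch.cat([image_tokens, text_tokens], dim=1)
        return self.decoder(inputs)

model = LlavaStyleModel()
logits = model(torch.randn(1, 3, 336, 336), torch.randint(0, 32000, (1, 16)))
print(logits.shape)  # (1, 256 + 16, 32000)
```

Under the two-stage recipe above, stage one would update only the projector while the vision encoder and decoder stay frozen, and stage two would update the projector and the decoder on the instruction-following data.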
[01:41:17.940 --> 01:41:21.860] And if you have limited compute, you [01:41:21.860 --> 01:41:25.100] can try it LoRa or QLoRa, or even just the projector only. [01:41:25.100 --> 01:41:29.020] It can give you decent visual chat performance. [01:41:29.020 --> 01:41:32.300] After we train Lava, we found several interesting emerging [01:41:32.300 --> 01:41:33.300] properties. [01:41:33.300 --> 01:41:36.420] Let's quickly revisit some of the data properties first. [01:41:36.420 --> 01:41:38.900] So our visual instructions are in English only, [01:41:38.900 --> 01:41:40.900] no human name annotation, and there's [01:41:40.900 --> 01:41:42.940] no explicit OCR data. [01:41:42.940 --> 01:41:45.700] Lava can have strong visual training capability, [01:41:45.700 --> 01:41:48.020] as GPT for Vision does, where we are able to figure out [01:41:48.020 --> 01:41:50.460] the unusualness is actually the man's ironing clothes [01:41:50.460 --> 01:41:51.780] on the back of a minivan. [01:41:51.780 --> 01:41:55.100] And it's more visually grounded than open-source baselines, [01:41:55.100 --> 01:41:57.860] like Blip2 and OpenFlamingo. [01:41:57.860 --> 01:42:00.620] It understands this humorously parodied Mona Lisa [01:42:00.620 --> 01:42:02.300] with a dog in the same pose. [01:42:02.300 --> 01:42:05.660] And it's definitely out of distribution. [01:42:05.660 --> 01:42:09.340] It also has a strong emerging OCR capability, [01:42:09.340 --> 01:42:13.020] where it can recognize NeurIPS 2023 from this presentation [01:42:13.020 --> 01:42:13.740] slide. [01:42:13.740 --> 01:42:17.380] And it correlates with the pre-trained language knowledge [01:42:17.380 --> 01:42:19.460] when asked about who will be interested in this. [01:42:19.460 --> 01:42:21.940] It will relate and say it's related [01:42:21.940 --> 01:42:24.740] to artificial intelligence and machine learning. [01:42:24.740 --> 01:42:27.460] Although our visual instructions are just in English, [01:42:27.460 --> 01:42:30.180] it's able to perform the reasoning and output [01:42:30.180 --> 01:42:33.020] text in Chinese and other foreign languages, [01:42:33.020 --> 01:42:34.720] like it recognizes the French quarter [01:42:34.720 --> 01:42:38.740] and performs a brief description in Chinese here. [01:42:38.740 --> 01:42:42.980] So how can we evaluate our large multimodal models? [01:42:42.980 --> 01:42:46.500] We draw our inspiration from NLP and slightly modify [01:42:46.500 --> 01:42:49.220] our data creation pipeline to use a text-only GPT [01:42:49.220 --> 01:42:51.500] to do the evaluation, where we have the image [01:42:51.500 --> 01:42:53.300] context in a textual format. [01:42:53.300 --> 01:42:55.120] We have the user instruction. [01:42:55.120 --> 01:42:57.020] We have two model outputs. [01:42:57.020 --> 01:42:59.740] And we just feed all of them into the text-only GPT [01:42:59.740 --> 01:43:01.380] and request for the feedback. [01:43:01.380 --> 01:43:04.660] It will give you a score out of 10 for each of the assistant [01:43:04.660 --> 01:43:06.660] and also provide you an explanation [01:43:06.660 --> 01:43:10.780] so that you can understand how the model is behaved. [01:43:10.780 --> 01:43:12.300] We create a challenging benchmark, [01:43:12.300 --> 01:43:14.980] Lava Bench in the Wild, which requires knowledge [01:43:14.980 --> 01:43:17.220] beyond training data, multilingual understanding, [01:43:17.220 --> 01:43:19.540] and also perception of subtle details. 
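For the GPT-assisted evaluation just described, a minimal sketch might look like the following: pack the textual image context, the instruction, and two candidate answers into one prompt, and ask a text-only model for scores out of 10 plus an explanation. The `ask_text_only_gpt` function and the prompt wording are placeholders, not the paper's actual prompts or a real API.

```python
# Sketch of GPT-based scoring for a visual instruction-following benchmark.
# `ask_text_only_gpt` is a placeholder for whatever text-only LLM endpoint you use.

def ask_text_only_gpt(prompt: str) -> str:
    raise NotImplementedError("plug in your own LLM call here")

def build_judge_prompt(image_context: str, instruction: str,
                       answer_a: str, answer_b: str) -> str:
    return (
        "You are judging two AI assistants on a question about an image.\n"
        f"Image description (captions and object boxes):\n{image_context}\n\n"
        f"Question: {instruction}\n\n"
        f"Assistant A: {answer_a}\n\n"
        f"Assistant B: {answer_b}\n\n"
        "Rate each assistant out of 10 for helpfulness, accuracy, and level of "
        "hallucination, then explain your scores."
    )

def judge(image_context, instruction, answer_a, answer_b):
    feedback = ask_text_only_gpt(
        build_judge_prompt(image_context, instruction, answer_a, answer_b))
    return feedback  # e.g. "A: 8, B: 5. Assistant B mentions an object that is not there ..."
```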
[01:43:19.540 --> 01:43:21.980] We create a very detailed textual annotation [01:43:21.980 --> 01:43:24.980] of the image context in those images. [01:43:24.980 --> 01:43:29.460] And we can feed them for a GPT evaluation. [01:43:29.460 --> 01:43:31.980] And it's not only just for accuracy, [01:43:31.980 --> 01:43:35.060] but also for hallucination. [01:43:35.060 --> 01:43:36.740] Since the introduction of Lava, there [01:43:36.740 --> 01:43:39.060] has been great effort from the community, [01:43:39.060 --> 01:43:41.980] ranging from data, model, modality, [01:43:41.980 --> 01:43:44.180] and expanding to different tasks, [01:43:44.180 --> 01:43:45.980] as well as developing benchmarks for us [01:43:45.980 --> 01:43:47.820] to better understand the model. [01:43:47.820 --> 01:43:50.060] The development after June are just too many [01:43:50.060 --> 01:43:51.740] to fit into the slide. [01:43:51.740 --> 01:43:54.380] And we, as Lava team, has also pushing the effort [01:43:54.380 --> 01:43:56.340] to make it more accessible and expanding [01:43:56.340 --> 01:43:59.940] its capability in terms of RLHF tool use, [01:43:59.940 --> 01:44:03.300] as well as visual prompting. [01:44:03.300 --> 01:44:06.460] Our improved version of Lava 1.5, [01:44:06.460 --> 01:44:10.420] with just simple modifications to the data and model, [01:44:10.420 --> 01:44:12.740] we show that it achieves great performance [01:44:12.740 --> 01:44:15.340] on a range of 12 benchmarks. [01:44:15.340 --> 01:44:17.060] And it's sample efficient that it only [01:44:17.060 --> 01:44:20.620] requires less than 1% of the data that other approach use. [01:44:20.620 --> 01:44:23.660] We're able to train Lava 1.5 within one day [01:44:23.660 --> 01:44:25.100] on a single node. [01:44:25.100 --> 01:44:30.260] Check out our workshop poster on Friday. [01:44:30.260 --> 01:44:33.700] So in conclusion, Lava can reason about the visual world, [01:44:33.700 --> 01:44:35.420] reflect with natural language. [01:44:35.420 --> 01:44:37.860] Its design is simple and general that we [01:44:37.860 --> 01:44:40.500] show that it is possible to adapt the language [01:44:40.500 --> 01:44:43.900] model to multi-model effectively and efficiently, [01:44:43.900 --> 01:44:47.660] that we can train it within one day on a single node. [01:44:47.660 --> 01:44:50.420] Because its design is so simple that we are compatible [01:44:50.420 --> 01:44:52.860] with almost all the optimizations that [01:44:52.860 --> 01:44:55.280] are designed for language models for both training [01:44:55.280 --> 01:44:59.380] and deployment, and it's fully open source. [01:44:59.380 --> 01:45:02.860] And unfortunately, due to the Wi-Fi network issue, [01:45:02.860 --> 01:45:05.060] we are unable to do a live demo here. [01:45:05.060 --> 01:45:08.700] But Lava is able to run on MacBook Air. [01:45:08.700 --> 01:45:11.100] So I will still do a live demo here. [01:45:11.100 --> 01:45:12.460] And let's try to see. [01:45:12.460 --> 01:45:16.100] This is an image that I took yesterday here. [01:45:16.100 --> 01:45:21.980] And I will just say, what's this event, and is it popular? [01:45:21.980 --> 01:45:24.100] It's the tree of salt presentation. [01:45:24.100 --> 01:45:25.140] It's really popular. [01:45:25.140 --> 01:45:28.700] I can just barely stand in the back. [01:45:28.700 --> 01:45:30.020] So it will just run for a while. [01:45:30.020 --> 01:45:31.660] And it says, the event seems to be [01:45:31.660 --> 01:45:33.740] a conference or presentation. 
[01:45:33.740 --> 01:45:35.500] There's a large number of people attending. [01:45:35.500 --> 01:45:39.660] And it appears to be popular, because lots of them [01:45:39.660 --> 01:45:41.300] are standing, including me. [01:45:41.300 --> 01:45:45.380] So I will say, OK, I attended as well. [01:45:45.380 --> 01:45:49.100] It is NeurIPS 2023. [01:45:49.100 --> 01:45:53.260] And the experience is great. [01:45:53.260 --> 01:45:56.700] Help me draft a tweet. [01:45:56.700 --> 01:45:58.660] And it will think for a while. [01:45:58.660 --> 01:46:01.980] I hope it will be optimized further [01:46:01.980 --> 01:46:04.100] for a faster filling stage. [01:46:04.100 --> 01:46:08.020] And just attended NeurIPS 2023, amazing experience, [01:46:08.020 --> 01:46:10.740] packed with knowledgeable speakers and attendees, [01:46:10.740 --> 01:46:13.580] learned so much, and made valuable connections. [01:46:13.580 --> 01:46:14.420] True. [01:46:14.420 --> 01:46:17.460] And highly recommend for anyone interested in AI-related fields [01:46:17.460 --> 01:46:18.940] and some hashtags. [01:46:18.940 --> 01:46:20.020] I hope you like it. [01:46:20.020 --> 01:46:21.220] And thank you so much. [01:46:21.220 --> 01:46:25.260] Please come to our poster session at number 229, [01:46:25.260 --> 01:46:30.740] our demo code, data, model, everything is open source. [01:46:30.740 --> 01:46:33.820] We are so excited to talk with you more about Lava. [01:46:33.820 --> 01:46:35.100] And thank you all for coming. [01:46:35.100 --> 01:46:40.540] [APPLAUSE] [01:46:40.540 --> 01:46:42.860] If DataComp is an example of what a really good benchmark [01:46:42.860 --> 01:46:45.100] and data set paper looks like, then I [01:46:45.100 --> 01:46:49.060] think Lava is an example of what really good kind of state [01:46:49.060 --> 01:46:51.420] of the art research on visual instruction [01:46:51.420 --> 01:46:53.660] tuning and visual language models looks like. [01:46:53.660 --> 01:46:55.300] It definitely has inspired a bunch [01:46:55.300 --> 01:46:59.060] of copycats and derivative work in the open source model [01:46:59.060 --> 01:47:00.500] space, notably Baklava. [01:47:00.500 --> 01:47:02.860] And I think there's just going to be a lot more work being [01:47:02.860 --> 01:47:03.500] done here. [01:47:03.500 --> 01:47:06.300] We're just realizing that we can plug and play these models [01:47:06.300 --> 01:47:08.580] and train them together in all sorts of ways. [01:47:08.580 --> 01:47:11.500] And Lava is definitely one of the more innovative solutions [01:47:11.500 --> 01:47:14.540] of that that also just solves simultaneously [01:47:14.540 --> 01:47:17.500] a whole bunch of issues with visual understanding. [01:47:17.500 --> 01:47:21.180] Here's the poster session Q&A with Hao Tian. [01:47:21.180 --> 01:47:26.060] Basically, we are trying to create a simple architecture, [01:47:26.060 --> 01:47:27.460] as simple as possible. [01:47:27.460 --> 01:47:31.140] So we have a vision encoder just to encode those features, [01:47:31.140 --> 01:47:33.660] a language model to perform the reasoning, [01:47:33.660 --> 01:47:36.660] and we use a projection layer, which is a linear layer. [01:47:36.660 --> 01:47:40.220] We find it doing pretty well to project the visual features [01:47:40.220 --> 01:47:42.500] to a latent space that the language decoder can [01:47:42.500 --> 01:47:43.400] understand. 
[01:47:43.400 --> 01:47:47.140] And we believe this is because the visual features of the clip [01:47:47.140 --> 01:47:50.700] already carry great semantics, are in a good latent space. [01:47:50.700 --> 01:47:53.940] So a single linear is sufficient for it to understand. [01:47:53.940 --> 01:47:55.780] Is the language model GPT-4? [01:47:55.780 --> 01:48:00.460] The language model is something open source, not GPT-4. [01:48:00.460 --> 01:48:02.260] You can take that off the shelf, but you're [01:48:02.260 --> 01:48:03.900] training the linear layer. [01:48:03.900 --> 01:48:06.460] So it will be two stage. [01:48:06.460 --> 01:48:09.140] In the first stage, we want to train the language model [01:48:09.140 --> 01:48:10.940] to understand those images. [01:48:10.940 --> 01:48:13.100] So we train the projection layer only. [01:48:13.100 --> 01:48:14.460] And this is our stage one. [01:48:14.460 --> 01:48:17.180] The language model and the vision encoder are frozen. [01:48:17.180 --> 01:48:20.220] And in the stage two, we will train the model [01:48:20.220 --> 01:48:22.220] to follow those instructions. [01:48:22.220 --> 01:48:24.940] So we train the language model and the projector. [01:48:24.940 --> 01:48:27.600] To my knowledge, this is the first work that is adding [01:48:27.600 --> 01:48:31.060] the bounding boxes with the captions. [01:48:31.060 --> 01:48:33.020] Any difficulty in having the language model [01:48:33.020 --> 01:48:34.140] understand all those things? [01:48:34.140 --> 01:48:37.500] Our model does not need to understand bounding box, [01:48:37.500 --> 01:48:41.700] because what we provide to train our model [01:48:41.700 --> 01:48:44.380] is this, visual instruction bounding data. [01:48:44.380 --> 01:48:46.380] The model just needs to understand the image [01:48:46.380 --> 01:48:49.460] and give a proper answer when you give a user's instruction. [01:48:49.460 --> 01:48:53.060] So this is not something our model needs to worry about, [01:48:53.060 --> 01:48:54.540] although we do find that model is [01:48:54.540 --> 01:48:57.660] able to understand those bounding boxes well. [01:48:57.660 --> 01:49:01.140] And key point is that does GPT-4 understand those well? [01:49:01.140 --> 01:49:04.220] And does a text-only GPT-4 understand that well? [01:49:04.220 --> 01:49:07.060] We find it to be true, because what we did [01:49:07.060 --> 01:49:09.860] is that I also work on some image generation model. [01:49:09.860 --> 01:49:13.460] And we have a work on that we can control the image layout [01:49:13.460 --> 01:49:16.620] by just providing some bounding boxes. [01:49:16.620 --> 01:49:19.580] What we did is that we give GPT-4 a caption, [01:49:19.580 --> 01:49:23.300] and we say, can you generate a reasonable layout for me? [01:49:23.300 --> 01:49:25.140] And it's able to do that pretty well. [01:49:25.140 --> 01:49:27.300] So we believe that-- we do not quantitatively [01:49:27.300 --> 01:49:30.420] evaluate how it's good at doing that, [01:49:30.420 --> 01:49:34.160] but it does understand those layout pretty well. [01:49:34.160 --> 01:49:36.020] And also, it can be used to-- [01:49:36.020 --> 01:49:37.820] and also, from the instruction it generated, [01:49:37.820 --> 01:49:41.300] it does know which is on the left, which is on the right. [01:49:41.300 --> 01:49:41.800] Yeah. [01:49:41.800 --> 01:49:44.380] Did you have to qualitatively evaluate [01:49:44.380 --> 01:49:47.460] the output of the answers that GPT-4 gave you? [01:49:47.460 --> 01:49:48.380] Yeah. 
[01:49:48.380 --> 01:49:51.020] Actually, we do not quantitatively evaluate, [01:49:51.020 --> 01:49:54.780] but we did manually go through some of them [01:49:54.780 --> 01:49:58.340] when we are developing those data engine, [01:49:58.340 --> 01:50:02.460] because we do have some factors to consider in this. [01:50:02.460 --> 01:50:06.820] So we can change the number of in-context examples. [01:50:06.820 --> 01:50:10.260] We can change the way we write those reference instructions [01:50:10.260 --> 01:50:11.180] and answers. [01:50:11.180 --> 01:50:15.060] We can also change the actual instructions [01:50:15.060 --> 01:50:19.580] we use to teach GPT on what is the task. [01:50:19.580 --> 01:50:23.660] So we did qualitatively iterate on how [01:50:23.660 --> 01:50:25.460] we design those data engine. [01:50:25.460 --> 01:50:27.740] And we find it-- [01:50:27.740 --> 01:50:30.460] this process really is quite rewarding, [01:50:30.460 --> 01:50:33.900] because we do, in this process, understand how GPT thinks [01:50:33.900 --> 01:50:36.500] and what are the information that we [01:50:36.500 --> 01:50:38.740] do need to provide GPT-4. [01:50:38.740 --> 01:50:39.240] Yeah. [01:50:39.240 --> 01:50:39.740] I see. [01:50:39.740 --> 01:50:43.580] And then all the bounding boxes that you provide, [01:50:43.580 --> 01:50:45.620] these are kind of ground truth, because you [01:50:45.620 --> 01:50:47.100] get them from code, right? [01:50:47.100 --> 01:50:47.780] That's correct. [01:50:47.780 --> 01:50:48.900] Right, right, right, right. [01:50:48.900 --> 01:50:53.320] So this actually ensures that those contexts are perfect, [01:50:53.320 --> 01:50:55.420] if the human annotators are perfect. [01:50:55.420 --> 01:50:57.660] And the generated instruction answers [01:50:57.660 --> 01:50:59.980] are as good as possible. [01:50:59.980 --> 01:51:02.020] I just want to ask about the training part. [01:51:02.020 --> 01:51:06.400] So if I were to take Lava 1.5 and fine-tune it, [01:51:06.400 --> 01:51:09.060] either full fine-tune or Laura or whatever, [01:51:09.060 --> 01:51:14.060] would you recommend also retraining the projector? [01:51:14.060 --> 01:51:15.820] I guess it depends on your task. [01:51:15.820 --> 01:51:19.100] Are you considering a different domain or-- [01:51:19.100 --> 01:51:22.080] So I want to build off the Lava+ stuff. [01:51:22.080 --> 01:51:24.980] So I know that goes into using tools. [01:51:24.980 --> 01:51:27.380] So it's a little out of scope for this project, [01:51:27.380 --> 01:51:29.780] but maybe for both. [01:51:29.780 --> 01:51:32.340] So let's say I want to take a different multimodal [01:51:32.340 --> 01:51:34.300] instruction following data set. [01:51:34.300 --> 01:51:38.180] For that part, would you recommend retraining-- [01:51:38.180 --> 01:51:39.380] Yeah, that's a good question. [01:51:39.380 --> 01:51:43.180] So I would say that if you want to-- [01:51:43.180 --> 01:51:45.420] if the domain, like the image domain [01:51:45.420 --> 01:51:47.220] that you're going to work on-- [01:51:47.220 --> 01:51:49.660] yeah, medical image, if it is too different, [01:51:49.660 --> 01:51:54.460] then I would recommend actually go with a different stage one [01:51:54.460 --> 01:51:55.180] training. [01:51:55.180 --> 01:51:57.340] Or even just do everything from scratch [01:51:57.340 --> 01:52:01.080] that you have a biomedical clip, right? [01:52:01.080 --> 01:52:04.260] Because that may give you even more benefit. 
[01:52:04.260 --> 01:52:08.820] But we do observe that if you pre-train [01:52:08.820 --> 01:52:11.180] with Lava's instructions, you pre-train [01:52:11.180 --> 01:52:15.900] with those visual information, it [01:52:15.900 --> 01:52:19.100] learns to do some reasoning about the visual contents. [01:52:19.100 --> 01:52:24.660] And it may be crucial for the visual understanding [01:52:24.660 --> 01:52:26.420] on different other domains. [01:52:26.420 --> 01:52:29.340] So I guess there will be a trade-off. [01:52:29.340 --> 01:52:33.620] And I guess there will be both pros and cons for training [01:52:33.620 --> 01:52:36.020] from another domain from scratch. [01:52:36.020 --> 01:52:37.980] Because you may lose the benefit that you [01:52:37.980 --> 01:52:42.180] get when pre-training on Lava on how to localize those objects. [01:52:42.180 --> 01:52:47.980] So I guess you would need some more experimental evidence [01:52:47.980 --> 01:52:49.580] on making the proper decision. [01:52:49.580 --> 01:52:51.900] So is it fair to say that unless the domain is [01:52:51.900 --> 01:52:56.020] super different, like x-rays, maybe it's fine to just-- [01:52:56.020 --> 01:52:57.620] Yeah, I think it's totally fine. [01:52:57.620 --> 01:53:00.540] And I guess it's better to use the instruction-tuned version, [01:53:00.540 --> 01:53:03.420] because it has so many vision knowledge injected into it. [01:53:03.420 --> 01:53:03.920] OK. [01:53:03.920 --> 01:53:05.420] And then, sorry, one last question. [01:53:05.420 --> 01:53:05.920] Of course. [01:53:05.920 --> 01:53:09.340] So for stage 2, let's say I want to fine-tune on my own thing, [01:53:09.340 --> 01:53:14.500] is the roughly 160k number of examples a good target to hit? [01:53:14.500 --> 01:53:16.780] Do you have recommendations around how big [01:53:16.780 --> 01:53:18.780] that data set should be? [01:53:18.780 --> 01:53:23.540] I guess it also depends on how different the task is, [01:53:23.540 --> 01:53:27.420] and also how bad the model is performing on that task. [01:53:27.420 --> 01:53:31.940] Because I can give a brief example on one [01:53:31.940 --> 01:53:33.660] of the experiments we have done. [01:53:33.660 --> 01:53:38.100] So there's a task that we can train the model to generate [01:53:38.100 --> 01:53:39.780] stable diffusion prompts, for example. [01:53:39.780 --> 01:53:44.420] Basically, it's captured in some style we want. [01:53:44.420 --> 01:53:48.660] And because the Lava is already able to understand [01:53:48.660 --> 01:53:52.660] those visual attributes, the content very well, [01:53:52.660 --> 01:53:56.180] it's just a form of reorganizing the style it responds. [01:53:56.180 --> 01:53:59.500] So we find even 100 examples is sufficient. [01:53:59.500 --> 01:54:00.100] 100? [01:54:00.100 --> 01:54:01.460] Yeah, 100. [01:54:01.460 --> 01:54:06.460] And we just used 100 examples, and it does the work decently. [01:54:06.460 --> 01:54:07.820] Yeah, the task is really easy. [01:54:07.820 --> 01:54:10.020] Yeah, it's just a form of changing the style. [01:54:10.020 --> 01:54:13.860] But if you're trying to do some very different reasoning [01:54:13.860 --> 01:54:17.460] tasks that Lava is not good at, I guess you may need more-- [01:54:17.460 --> 01:54:19.680] Also, I think 10k, generally, is enough. [01:54:19.680 --> 01:54:20.180] What? [01:54:20.180 --> 01:54:21.700] 10k, generally, is enough. [01:54:21.700 --> 01:54:22.460] Yeah, yeah, yeah. 
[01:54:22.460 --> 01:54:25.660] I guess 10k, or if you want to make it safe, [01:54:25.660 --> 01:54:30.260] I guess maybe 50k is at most. [01:54:30.260 --> 01:54:30.760] Yeah. [01:54:30.760 --> 01:54:32.060] Can I ask a question? [01:54:32.060 --> 01:54:32.740] Yeah, of course. [01:54:32.740 --> 01:54:33.500] I just want to understand-- [01:54:33.500 --> 01:54:34.060] Thank you so much. [01:54:34.060 --> 01:54:34.780] Yeah, thank you. [01:54:34.780 --> 01:54:36.980] Understand how important this vision encoder is. [01:54:36.980 --> 01:54:39.500] Have you ever tried to remove the encoder entirely [01:54:39.500 --> 01:54:44.060] and use the bounding box here as the input of whatever [01:54:44.060 --> 01:54:47.620] language model you are using and just do the same task? [01:54:47.620 --> 01:54:51.780] So I guess the key point here is that if you [01:54:51.780 --> 01:54:54.820] want to remove the encoder completely [01:54:54.820 --> 01:54:57.660] and just use the bounding boxes as the input, [01:54:57.660 --> 01:54:59.120] there will be one question like, how [01:54:59.120 --> 01:55:00.380] are you going to get those bounding boxes? [01:55:00.380 --> 01:55:05.040] And second is that, what if the user asks you about the text? [01:55:05.040 --> 01:55:07.320] Like, are you going to also have an OCR engine? [01:55:07.320 --> 01:55:10.520] And what if the user asks about something else, [01:55:10.520 --> 01:55:12.440] for example, like the attribute? [01:55:12.440 --> 01:55:17.760] And if you think of this, having an end-to-end model [01:55:17.760 --> 01:55:21.440] will make it much more easier and much more generalizable [01:55:21.440 --> 01:55:23.480] to extend to different types of the inputs [01:55:23.480 --> 01:55:26.040] and the user's instructions. [01:55:26.040 --> 01:55:29.240] And also, because now you need some other model [01:55:29.240 --> 01:55:32.200] to generate those bounding boxes, text, [01:55:32.200 --> 01:55:35.840] all of those things, I feel that it's [01:55:35.840 --> 01:55:39.320] good if we can have those models to enhance the capability. [01:55:39.320 --> 01:55:43.240] But you do have a model that are really trained with vision [01:55:43.240 --> 01:55:45.960] and really understand what's happening in this image. [01:55:45.960 --> 01:55:48.840] It can better coordinate those information. [01:55:48.840 --> 01:55:49.340] Yeah. [01:55:49.340 --> 01:55:52.680] So have you ever unfreezed this vision encoder? [01:55:52.680 --> 01:55:53.180] Yes. [01:55:53.180 --> 01:55:54.840] Meaning, during first or second session? [01:55:54.840 --> 01:55:55.360] Yes, yes. [01:55:55.360 --> 01:55:57.520] We have tried to unfreeze the vision encoder. [01:55:57.520 --> 01:55:59.720] And we find it quite useful for some of the text, [01:55:59.720 --> 01:56:00.880] but not for the other. [01:56:00.880 --> 01:56:03.620] So specifically, if it's just asking about, [01:56:03.620 --> 01:56:04.800] what's the attribute? [01:56:04.800 --> 01:56:06.800] What's the object? [01:56:06.800 --> 01:56:09.640] Those kind of tasks, it does not matter much. [01:56:09.640 --> 01:56:13.440] But there are two kinds of tasks that unfreezing the vision [01:56:13.440 --> 01:56:14.960] encoder really matters. [01:56:14.960 --> 01:56:17.560] One is that it's not necessarily about the semantics. [01:56:17.560 --> 01:56:22.000] For example, I'm asking whether this line is straight. 
[01:56:22.000 --> 01:56:23.760] Those kind of tasks which require [01:56:23.760 --> 01:56:25.560] you to understand the low-level details [01:56:25.560 --> 01:56:27.800] or the low-level detail really matters. [01:56:27.800 --> 01:56:29.760] It's one of the things that-- [01:56:29.760 --> 01:56:32.500] and we also have another work, VIP Lava, [01:56:32.500 --> 01:56:35.680] where we try to train the model to understand [01:56:35.680 --> 01:56:36.800] the visual prompts. [01:56:36.800 --> 01:56:38.920] So basically, the visual prompts, we mean that, [01:56:38.920 --> 01:56:42.400] can we just use some scribble to circle some objects [01:56:42.400 --> 01:56:44.920] that we want to ask about instead of necessarily trying [01:56:44.920 --> 01:56:49.400] to describe it very clearly on making the model to understand [01:56:49.400 --> 01:56:51.160] what we are curious about? [01:56:51.160 --> 01:56:55.000] So for that, in order to correctly identify [01:56:55.000 --> 01:57:01.280] those scribbles and those tiny lines, [01:57:01.280 --> 01:57:04.800] it requires you to somehow unfreeze the vision encoder [01:57:04.800 --> 01:57:10.280] to properly unfreeze or use some earlier layers, which still [01:57:10.280 --> 01:57:11.840] preserves those information. [01:57:11.840 --> 01:57:14.200] I'm curious about the backstory behind this whole thing. [01:57:14.200 --> 01:57:17.280] How did you get started exploring multimodality [01:57:17.280 --> 01:57:19.120] and your inspiration? [01:57:19.120 --> 01:57:22.720] We have been working on visual language since-- [01:57:22.720 --> 01:57:25.320] our team has been also working on visual language and-- [01:57:25.320 --> 01:57:26.480] Your team, is it a lab? [01:57:26.480 --> 01:57:29.720] Is it a-- his team, I see. [01:57:29.720 --> 01:57:33.360] Chun-Yu from Microsoft and we and my advisor, [01:57:33.360 --> 01:57:36.520] we have a collaborative effort on this. [01:57:36.520 --> 01:57:39.880] We have a series of work on visual language. [01:57:39.880 --> 01:57:46.360] And although I'm not having tons of years' [01:57:46.360 --> 01:57:50.880] experience on visual language, but we do see that-- [01:57:50.880 --> 01:57:55.760] in March, we see Vicuna, which makes us very impressed [01:57:55.760 --> 01:57:57.720] about the performance it can have. [01:57:57.720 --> 01:57:58.440] For the size. [01:57:58.440 --> 01:58:00.960] Yeah, for the size and also for the open source. [01:58:00.960 --> 01:58:03.320] And we believe that it's possible for us [01:58:03.320 --> 01:58:07.400] to create a visual reasoning model that [01:58:07.400 --> 01:58:11.040] is purely open source with similar level of capability. [01:58:11.040 --> 01:58:13.960] And we believe that with open source, [01:58:13.960 --> 01:58:16.840] we are able to have a joint effort from the community [01:58:16.840 --> 01:58:18.720] to make it much, much better. [01:58:18.720 --> 01:58:19.800] Yeah, it was cheap, right? [01:58:19.800 --> 01:58:21.520] You trained for like eight hours. [01:58:21.520 --> 01:58:26.080] Yeah, eight hours or 11.5 one day on a single node. [01:58:26.080 --> 01:58:27.880] That means everyone else can do it too. [01:58:27.880 --> 01:58:31.400] Yes, not everyone else, but most people. [01:58:31.400 --> 01:58:32.400] Thank you very much. [01:58:32.400 --> 01:58:35.200] So super interesting and notable work on the Lava model. [01:58:35.200 --> 01:58:37.080] I guess someone should try to hire him. 
[01:58:37.080 --> 01:58:41.040] But I guess the next segment we're going to explore [01:58:41.040 --> 01:58:44.520] is the prompting segment, quote unquote. [01:58:44.520 --> 01:58:46.920] And there are a surprising number of prompting papers [01:58:46.920 --> 01:58:48.120] here. [01:58:48.120 --> 01:58:50.360] I'm not sure that many papers should [01:58:50.360 --> 01:58:52.080] be represented at NeurIPS. [01:58:52.080 --> 01:58:53.760] But where else are they going to present? [01:58:53.760 --> 01:58:54.720] I don't really know. [01:58:54.720 --> 01:58:58.560] But anyway, so there was a whole channel or track just [01:58:58.560 --> 01:58:59.800] on chain of thought. [01:58:59.800 --> 01:59:02.720] That blows my mind. [01:59:02.720 --> 01:59:04.760] And I do think that that is appropriate. [01:59:04.760 --> 01:59:07.480] And I do think that the techniques here are innovative. [01:59:07.480 --> 01:59:09.040] It's impossible to cover all of them. [01:59:09.040 --> 01:59:11.200] I actually talked to Noah Shinn from Reflexion. [01:59:11.200 --> 01:59:12.720] Remember Reflexion? [01:59:12.720 --> 01:59:14.320] As well as a whole bunch of others. [01:59:14.320 --> 01:59:16.800] But probably the most representative one [01:59:16.800 --> 01:59:18.400] was the Tree of Thoughts paper. [01:59:18.400 --> 01:59:19.160] So here it is. [01:59:19.160 --> 01:59:20.960] My name is Shunyu. [01:59:20.960 --> 01:59:21.840] I'm from Princeton. [01:59:21.840 --> 01:59:24.120] I'm very excited to talk about Tree of Thoughts. [01:59:24.120 --> 01:59:25.640] It's a joint work with my colleagues [01:59:25.640 --> 01:59:29.640] from Princeton and Google. [01:59:29.640 --> 01:59:32.720] So we all know language models and large language models. [01:59:32.720 --> 01:59:35.000] Language models were invented to generate text, token [01:59:35.000 --> 01:59:37.640] by token, left to right. [01:59:37.640 --> 01:59:41.160] But now they are used to solve an increasingly wide range [01:59:41.160 --> 01:59:44.640] of problems using scaled-up models and prompting [01:59:44.640 --> 01:59:47.000] techniques like chain of thought. [01:59:47.000 --> 01:59:48.360] So here is an example. [01:59:48.360 --> 01:59:51.800] You can break down a complex calculation into steps, [01:59:51.800 --> 01:59:54.200] and that lets the model solve problems [01:59:54.200 --> 01:59:55.840] that it cannot solve in one step. [01:59:55.840 --> 01:59:59.960] So the question is, can those language models one day [01:59:59.960 --> 02:00:02.680] become a general problem solver just by scaling up and using [02:00:02.680 --> 02:00:04.120] autoregressive inference? [02:00:04.120 --> 02:00:07.440] Or are there some fundamental limitations? [02:00:07.440 --> 02:00:10.160] So to answer the question, let's take a look [02:00:10.160 --> 02:00:11.960] at a very simple example. [02:00:11.960 --> 02:00:16.320] The game of 24, where the rule is you are given four numbers. [02:00:16.320 --> 02:00:20.960] And you have plus, minus, divide, and multiply operations. [02:00:20.960 --> 02:00:24.680] And you need to combine those four numbers to obtain 24. [02:00:24.680 --> 02:00:32.640] So one example is if you are given input 2, 9, 10, and 12, [02:00:32.640 --> 02:00:36.720] what you can do is you can first multiply 12 and 2 to get 24, [02:00:36.720 --> 02:00:41.480] then 10 minus 9 to get 1, then 24 times 1 to get 24. [02:00:41.480 --> 02:00:43.560] So it's not a really hard game.
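As a concrete aside on the Game of 24 rules just laid out, here is a small brute-force checker that tries every way of combining the four numbers with the four operations and reports whether 24 is reachable; it is only meant to illustrate the game itself, not anything from the talk.

```python
from itertools import permutations

# Brute-force check for the Game of 24: can the four numbers be combined
# with +, -, *, / (each number used exactly once) to reach 24?
def solvable_24(numbers, target=24.0, eps=1e-6):
    nums = [float(n) for n in numbers]

    def search(vals):
        if len(vals) == 1:
            return abs(vals[0] - target) < eps
        # Pick an ordered pair of values, combine them, and recurse on the rest.
        for a, b in permutations(range(len(vals)), 2):
            rest = [vals[i] for i in range(len(vals)) if i not in (a, b)]
            x, y = vals[a], vals[b]
            candidates = [x + y, x - y, x * y]
            if abs(y) > eps:
                candidates.append(x / y)
            if any(search(rest + [c]) for c in candidates):
                return True
        return False

    return search(nums)

print(solvable_24([2, 9, 10, 12]))  # True: (12 * 2) * (10 - 9) = 24
print(solvable_24([1, 1, 1, 1]))    # False: 24 is unreachable
```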
[02:00:43.560 --> 02:00:47.760] Now you give a new input, 4, 5, 6, 10, to GPT-3.5, [02:00:47.760 --> 02:00:49.240] and have it solve the task. [02:00:49.240 --> 02:00:52.480] It will first try to multiply 10 and 6 to get 60, [02:00:52.480 --> 02:00:56.080] then divide it by 5 to get 12, then 12 times 4 to get 48. [02:00:56.080 --> 02:00:59.960] Then, to make it up, it will say it's 24 and call it a day. [02:00:59.960 --> 02:01:02.360] So it's a hallucination. [02:01:02.360 --> 02:01:05.240] You might argue that if you have better models or better prompts [02:01:05.240 --> 02:01:06.040] it will solve this. [02:01:06.040 --> 02:01:09.920] But even if you use GPT-4 with five examples [02:01:09.920 --> 02:01:13.800] in this chain-of-thought prompt, it will only get 4% task success. [02:01:13.800 --> 02:01:17.560] So why is this easy task so hard for language models? [02:01:17.560 --> 02:01:23.240] So if you look at the initial token generation, 10, then times, [02:01:23.240 --> 02:01:26.720] then 6, because those language models [02:01:26.720 --> 02:01:30.240] are making local and token-level decisions, one by one, [02:01:30.240 --> 02:01:34.840] left to right, those initial decisions are really hard. [02:01:34.840 --> 02:01:37.400] Even for humans, we don't know whether the first token should [02:01:37.400 --> 02:01:38.920] be 10 or 6 or 5. [02:01:38.920 --> 02:01:41.480] We don't have pre-trained intuition. [02:01:41.480 --> 02:01:44.360] We have to play the game to have a better sense. [02:01:44.360 --> 02:01:46.280] Worse still, once you generate those first tokens [02:01:46.280 --> 02:01:48.160] at the beginning, the task has already [02:01:48.160 --> 02:01:51.400] failed, in that you cannot really complete the whole trajectory [02:01:51.400 --> 02:01:54.200] in a chain-of-thought format and be right. [02:01:54.200 --> 02:01:57.080] So by this very simple example, what I want to show [02:01:57.080 --> 02:02:00.880] is there is something about autoregressive inference [02:02:00.880 --> 02:02:04.600] that lacks mechanisms for deliberate reasoning. [02:02:04.600 --> 02:02:10.000] And that's true even for the biggest, strongest language models [02:02:10.000 --> 02:02:11.120] like GPT-4. [02:02:11.120 --> 02:02:12.760] And the reason is quite simple, just [02:02:12.760 --> 02:02:13.920] like Ben's talk mentioned. [02:02:13.920 --> 02:02:16.680] So for chain of thought to work, you really [02:02:16.680 --> 02:02:19.140] need strong local signals to guide every step [02:02:19.140 --> 02:02:20.840] through those local decisions. [02:02:20.840 --> 02:02:23.220] And just to draw an analogy, imagine if you [02:02:23.220 --> 02:02:28.040] have a robot that's trained only on successful navigation [02:02:28.040 --> 02:02:29.400] trajectories. [02:02:29.400 --> 02:02:31.800] And it's only trained to predict the next move. [02:02:31.800 --> 02:02:33.440] And then you put it into a new maze. [02:02:33.440 --> 02:02:36.280] Then it's very hard for it to explore. [02:02:36.280 --> 02:02:39.360] So how do we solve this issue? [02:02:39.360 --> 02:02:40.740] So in this work, we took inspiration [02:02:40.740 --> 02:02:42.920] from human cognition. [02:02:42.920 --> 02:02:46.320] In his famous book, Thinking, Fast and Slow, [02:02:46.320 --> 02:02:49.640] Daniel Kahneman proposed that our cognition has two parts. [02:02:49.640 --> 02:02:51.880] We have a fast and automatic system [02:02:51.880 --> 02:02:54.520] one that's handling everyday tasks, like riding a bike.
[02:02:54.520 --> 02:02:56.320] And we have a slow and deliberate system two [02:02:56.320 --> 02:02:59.320] that's imposing control and intervention over system one [02:02:59.320 --> 02:03:02.640] for harder tasks, like designing a plan. [02:03:02.640 --> 02:03:07.080] So if language models' autoregressive inference [02:03:07.080 --> 02:03:11.360] is similar to this spontaneous but error-prone system one [02:03:11.360 --> 02:03:14.480] process, maybe we can impose some kind of control algorithm [02:03:14.480 --> 02:03:17.880] on top of it to get system two reasoning. [02:03:17.880 --> 02:03:20.920] And tree search is naturally the choice, [02:03:20.920 --> 02:03:22.680] which is also one of the oldest ideas [02:03:22.680 --> 02:03:23.840] in artificial intelligence. [02:03:23.840 --> 02:03:27.200] For example, Newell and Simon's General Problem Solver [02:03:27.200 --> 02:03:29.920] in the 1950s. [02:03:29.920 --> 02:03:34.280] However, doing search in this reasoning space [02:03:34.280 --> 02:03:37.360] is non-trivial, because traditionally, [02:03:37.360 --> 02:03:39.800] if we search in classical games, like chess, [02:03:39.800 --> 02:03:43.360] we often have a small fixed set of next moves [02:03:43.360 --> 02:03:47.120] so that we can design or learn search heuristics. [02:03:47.120 --> 02:03:49.800] But if we want to search in open-ended reasoning, [02:03:49.800 --> 02:03:51.720] the next move can be arbitrary text, [02:03:51.720 --> 02:03:55.120] which is really hard to enumerate or evaluate. [02:03:55.120 --> 02:03:58.840] So the idea here is now that we have large language models, [02:03:58.840 --> 02:04:00.880] we can use them to start generating and evaluating [02:04:00.880 --> 02:04:02.280] next moves. [02:04:02.280 --> 02:04:04.080] So from the previous two slides, [02:04:04.080 --> 02:04:06.680] you have seen what's the problem with large language models [02:04:06.680 --> 02:04:08.560] and what's the problem with classical search. [02:04:08.560 --> 02:04:11.800] And the hint is that combining them might lead to a better result. [02:04:11.800 --> 02:04:12.800] And that's true. [02:04:12.800 --> 02:04:14.560] So we propose Tree of Thoughts. [02:04:14.560 --> 02:04:16.880] It's a general method for combining language models [02:04:16.880 --> 02:04:19.720] and search algorithms for deliberate reasoning. [02:04:19.720 --> 02:04:21.920] And to solve a problem, you need four parts. [02:04:21.920 --> 02:04:24.080] So first, you need to define what is a search space [02:04:24.080 --> 02:04:26.280] or what is a thought space. [02:04:26.280 --> 02:04:28.800] Then you need to generate and evaluate [02:04:28.800 --> 02:04:30.360] thoughts using language models. [02:04:30.360 --> 02:04:32.520] And you need to combine that with a search algorithm [02:04:32.520 --> 02:04:33.880] to explore and maintain thoughts. [02:04:33.880 --> 02:04:36.480] So I'll use the simplest example, which is Game of 24, [02:04:36.480 --> 02:04:38.280] to explain each part. [02:04:38.280 --> 02:04:40.280] OK, so what is a thought? [02:04:40.280 --> 02:04:41.920] That's not a question in chain of thought, [02:04:41.920 --> 02:04:44.520] because everything is one coherent generation, and you [02:04:44.520 --> 02:04:45.720] don't have to split it. [02:04:45.720 --> 02:04:48.280] But it's a very critical thing in Tree of Thoughts. [02:04:48.280 --> 02:04:51.160] So here, we define a thought as a coherent piece of text. [02:04:51.160 --> 02:04:53.240] It's the next move in the reasoning game.
[02:04:53.240 --> 02:04:55.600] And if you think about Game of 24, [02:04:55.600 --> 02:04:57.600] there are two extreme choices. [02:04:57.600 --> 02:05:01.320] On one extreme, you can treat each token as a thought. [02:05:01.320 --> 02:05:03.960] Then it will be very easy to generate each thought. [02:05:03.960 --> 02:05:05.720] But as explained before, it's very hard [02:05:05.720 --> 02:05:10.000] to evaluate whether 10 is a good thought or 13 is a good thought. [02:05:10.000 --> 02:05:12.280] On the other extreme, you can treat the whole reasoning [02:05:12.280 --> 02:05:13.240] as a thought. [02:05:13.240 --> 02:05:14.600] You generate the whole thing, which [02:05:14.600 --> 02:05:15.840] will be very easy to evaluate. [02:05:15.840 --> 02:05:18.760] You just look at the end if the number is 24. [02:05:18.760 --> 02:05:21.840] But if you can generate that, the problem is solved already. [02:05:21.840 --> 02:05:24.600] So it's very hard to generate. [02:05:24.600 --> 02:05:27.440] So in this game, naturally, the choice of thought [02:05:27.440 --> 02:05:29.000] is something in between. [02:05:29.000 --> 02:05:30.640] We can use each intermediate equation [02:05:30.640 --> 02:05:33.880] as a thought so that it's relatively easy to generate [02:05:33.880 --> 02:05:35.340] and evaluate thoughts. [02:05:35.340 --> 02:05:38.000] And this is really a problem-specific trade-off [02:05:38.000 --> 02:05:38.640] design. [02:05:38.640 --> 02:05:40.960] So for different problems, a thought can be a token, [02:05:40.960 --> 02:05:42.720] can be a word, can be a sentence, [02:05:42.720 --> 02:05:46.800] can be a paragraph, and so on. [02:05:46.800 --> 02:05:49.040] So once you have defined what is a thought, [02:05:49.040 --> 02:05:51.040] it's easy to generate that with language models. [02:05:51.040 --> 02:05:53.080] So here, it's a simple prompt. [02:05:53.080 --> 02:05:55.200] We have one example of what's the input [02:05:55.200 --> 02:05:57.200] and what's the possible thoughts. [02:05:57.200 --> 02:05:59.760] Then you give it a new input, and the language model [02:05:59.760 --> 02:06:02.000] just can generate a new thought. [02:06:02.000 --> 02:06:03.680] So here, each new line is a new thought [02:06:03.680 --> 02:06:05.480] of how to continue the reasoning. [02:06:05.480 --> 02:06:09.360] Once you have those thoughts, you [02:06:09.360 --> 02:06:12.720] want to give them a value so that you can search. [02:06:12.720 --> 02:06:16.880] So here, what we do is we give this prompt of example [02:06:16.880 --> 02:06:19.200] where, for the remaining numbers, [02:06:19.200 --> 02:06:21.200] if the language model can simulate within a field [02:06:21.200 --> 02:06:23.880] trials and reach 24, then a high value is given. [02:06:23.880 --> 02:06:27.800] If not, depending on whether the numbers look reasonable or not, [02:06:27.800 --> 02:06:30.560] a medium or a low value is given. [02:06:30.560 --> 02:06:34.320] So the previous three examples are the in-context examples. [02:06:34.320 --> 02:06:38.360] Now, for this new input, 5, 6, 6, the language model [02:06:38.360 --> 02:06:42.040] try one round, find 24, and sure, it's a high value. [02:06:42.040 --> 02:06:47.440] For turn 13, 13, it will try a few rounds, and it will fail. [02:06:47.440 --> 02:06:49.480] And these numbers look too large, [02:06:49.480 --> 02:06:52.480] so it's impossible, so a low value. [02:06:52.480 --> 02:06:55.360] So for something like 5, 5, 9, it try a few rounds. 
[02:06:55.360 --> 02:06:58.280] It doesn't work, but the numbers look reasonable, so likely, [02:06:58.280 --> 02:07:00.520] so a medium value. [02:07:00.520 --> 02:07:02.400] But here, actually, 5, 5, 9, it's [02:07:02.400 --> 02:07:04.280] not actually possible to reach 24. [02:07:04.280 --> 02:07:07.280] So it's important to know that, just like any search heuristic, [02:07:07.280 --> 02:07:09.920] here, the value does not have to be perfect. [02:07:09.920 --> 02:07:11.680] It just needs to bias the search [02:07:11.680 --> 02:07:13.720] toward promising directions. [02:07:13.720 --> 02:07:16.720] Also, here, the prompt uses common sense reasoning [02:07:16.720 --> 02:07:18.880] and simulation, but you can really [02:07:18.880 --> 02:07:21.360] design different strategies for different problems. [02:07:21.360 --> 02:07:23.560] It's really flexible. [02:07:23.560 --> 02:07:25.180] So lastly, you can combine them together [02:07:25.180 --> 02:07:26.440] with a tree-search algorithm. [02:07:26.440 --> 02:07:27.840] Here, we use breadth-first search, [02:07:27.840 --> 02:07:29.960] which is the simplest algorithm. [02:07:29.960 --> 02:07:34.400] You have a depth of 3, and you have a breadth from 1 to 5. [02:07:34.400 --> 02:07:35.640] And the idea is very simple. [02:07:35.640 --> 02:07:36.480] You have the input. [02:07:36.480 --> 02:07:37.680] You generate a bunch of thoughts. [02:07:37.680 --> 02:07:38.640] You evaluate them. [02:07:38.640 --> 02:07:40.500] You only keep the top choices, so it's [02:07:40.500 --> 02:07:43.160] like a thought-level beam search. [02:07:43.160 --> 02:07:45.560] And you keep doing that until four numbers [02:07:45.560 --> 02:07:47.340] become three numbers, three numbers become two numbers, [02:07:47.340 --> 02:07:48.760] and two numbers become one number. [02:07:48.760 --> 02:07:51.880] And you succeed if the only number is 24. [02:07:51.880 --> 02:07:56.800] So while CoT only achieved 4%, ToT with a breadth of 1 [02:07:56.800 --> 02:07:59.200] already leads to 45%, and a breadth of 5 [02:07:59.200 --> 02:08:02.520] leads to an even higher 74%. [02:08:02.520 --> 02:08:04.840] We can also use a similar idea for different algorithms [02:08:04.840 --> 02:08:06.120] and for different problems. [02:08:06.120 --> 02:08:09.240] So for example, for crosswords, right, [02:08:09.240 --> 02:08:12.840] suppose you have five clues horizontally and five clues [02:08:12.840 --> 02:08:13.800] vertically. [02:08:13.800 --> 02:08:16.440] What you can do is you can generate a bunch of thoughts, [02:08:16.440 --> 02:08:19.680] evaluate them, and then proceed only [02:08:19.680 --> 02:08:22.120] with the most promising choice. [02:08:22.120 --> 02:08:24.960] So that's a depth-first search, rather than a breadth-first search. [02:08:24.960 --> 02:08:27.520] And you can keep doing this until the language model [02:08:27.520 --> 02:08:30.440] realizes this board is no longer solvable. [02:08:30.440 --> 02:08:32.760] Then what you do is you prune the subtree, [02:08:32.760 --> 02:08:34.000] and then you backtrack, right? [02:08:34.000 --> 02:08:36.680] So you move on to the next branch, but maybe that is still not [02:08:36.680 --> 02:08:38.560] solvable, so you will move on again. [02:08:38.560 --> 02:08:41.200] If none of them works, then you go back one level, [02:08:41.200 --> 02:08:42.200] and then you try again. [02:08:42.200 --> 02:08:44.840] So it's a very classic depth-first search. [02:08:44.840 --> 02:08:48.200] And here are the results.
[02:08:48.200 --> 02:08:51.160] CoT reached 1%, and we got 20%. [02:08:51.160 --> 02:08:52.960] But if you don't have pruning and backtracking, [02:08:52.960 --> 02:08:57.960] then it drops back to 5%, which shows pruning and backtracking [02:08:57.960 --> 02:08:59.840] are very important. [02:08:59.840 --> 02:09:03.040] So in our paper, we have these two games, [02:09:03.040 --> 02:09:05.520] but we also have a natural language task that's [02:09:05.520 --> 02:09:07.040] trying to write creative stories. [02:09:07.040 --> 02:09:08.960] And the intuition is also very simple, right? [02:09:08.960 --> 02:09:11.480] So if you're a good writer, you don't just [02:09:11.480 --> 02:09:12.800] write token by token, right? [02:09:12.800 --> 02:09:15.980] You will deliberately plan what are the possible plots. [02:09:15.980 --> 02:09:18.680] You will compare between them, and you will select one. [02:09:18.680 --> 02:09:20.320] So similarly here, the language model [02:09:20.320 --> 02:09:23.480] will write a bunch of diverse plans, [02:09:23.480 --> 02:09:26.320] then self-evaluate what is a good plan, [02:09:26.320 --> 02:09:27.600] then proceed with that. [02:09:27.600 --> 02:09:30.800] And you can do this kind of search for writing, [02:09:30.800 --> 02:09:36.800] and then humans will find it more creative than the CoT [02:09:36.800 --> 02:09:37.440] writing. [02:09:37.440 --> 02:09:40.080] But the writing is too complex and long, [02:09:40.080 --> 02:09:42.280] so I cannot display it here. [02:09:42.280 --> 02:09:45.520] So what I want to say is across those different tasks [02:09:45.520 --> 02:09:47.920] with very different reasoning challenges, [02:09:47.920 --> 02:09:50.560] the modular design of ToT allows us [02:09:50.560 --> 02:09:53.760] to have very flexible ways to generate, evaluate, [02:09:53.760 --> 02:09:57.400] and search thoughts across very general and diverse tasks. [02:09:57.400 --> 02:10:00.760] And we're doing so in a very systematic framework [02:10:00.760 --> 02:10:03.400] and achieving very good performance [02:10:03.400 --> 02:10:04.700] without retraining any models. [02:10:04.700 --> 02:10:07.280] So it's very convenient to use. [02:10:07.280 --> 02:10:11.120] So we believe this is an initial step toward connecting [02:10:11.120 --> 02:10:13.320] old insights and new frontiers of AI. [02:10:13.320 --> 02:10:17.280] So here, tree search, one of the oldest ideas in AI, [02:10:17.280 --> 02:10:20.040] helps language models do more deliberate reasoning, [02:10:20.040 --> 02:10:22.620] while language models help search by providing it [02:10:22.620 --> 02:10:27.400] with very flexible, general-purpose, and powerful heuristics. [02:10:27.400 --> 02:10:28.920] So we have these follow-up efforts [02:10:28.920 --> 02:10:31.880] trying to connect cognitive architectures to language [02:10:31.880 --> 02:10:33.360] model-based agents. [02:10:33.360 --> 02:10:37.520] So those are systems that do not just reason internally, [02:10:37.520 --> 02:10:39.900] but also interact with the external world [02:10:39.900 --> 02:10:42.120] and learn through such interaction continuously. [02:10:42.120 --> 02:10:45.160] So it's like autonomous agents. [02:10:45.160 --> 02:10:46.840] So we have this follow-up paper called [02:10:46.840 --> 02:10:50.520] CoALA, Cognitive Architectures for Language Agents. [02:10:50.520 --> 02:10:54.160] I highly encourage you guys to check it out. [02:10:54.160 --> 02:10:56.600] And I thank my co-authors.
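Pulling the talk's recipe together, here is a minimal sketch of the breadth-first-search flavor of Tree of Thoughts for the Game of 24: a language model proposes candidate next equations, a language model scores them with a sure/likely/impossible heuristic, and only the top-b partial solutions survive each step. The `llm` function and prompt strings are placeholders, and this paraphrases the idea rather than reproducing the authors' released code.

```python
# Minimal BFS-style Tree of Thoughts sketch for the Game of 24.
# `llm(prompt)` is a placeholder for a call to your language model of choice.

def llm(prompt: str) -> str:
    raise NotImplementedError("plug in an LLM call here")

VALUE = {"sure": 3.0, "likely": 1.0, "impossible": 0.0}

def propose_thoughts(state: str, k: int = 8) -> list[str]:
    """Ask the LLM for k candidate next steps (e.g. '10 * 6 = 60 (left: 4 5 60)')."""
    out = llm(f"Input: {state}\nPropose {k} possible next equations, one per line.")
    return out.strip().splitlines()[:k]

def value_thought(state: str, thought: str) -> float:
    """Ask the LLM whether the remaining numbers can still reach 24."""
    verdict = llm(f"Given {state} and the step '{thought}', can the remaining "
                  "numbers reach 24? Answer sure, likely, or impossible.")
    return VALUE.get(verdict.strip().lower(), 0.0)

def tree_of_thoughts_bfs(problem: str, depth: int = 3, breadth: int = 5) -> list[str]:
    frontier = [[]]  # each element is a list of thoughts (a partial solution)
    for _ in range(depth):
        candidates = []
        for path in frontier:
            state = problem + " | " + " ; ".join(path)
            for thought in propose_thoughts(state):
                score = value_thought(state, thought)
                candidates.append((score, path + [thought]))
        # Keep only the top-`breadth` partial solutions (thought-level beam search).
        candidates.sort(key=lambda c: c[0], reverse=True)
        frontier = [path for _, path in candidates[:breadth]]
    return frontier  # the surviving reasoning chains after `depth` steps
```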
[02:10:56.600 --> 02:10:58.440] Thank you guys for listening. [02:10:58.440 --> 02:10:59.560] Check out the poster today. [02:10:59.560 --> 02:11:00.840] And happy to chat. [02:11:00.840 --> 02:11:01.800] Thank you so much. [02:11:01.800 --> 02:11:05.120] [APPLAUSE] [02:11:05.120 --> 02:11:08.920] I do like it when people come up with a general enough model [02:11:08.920 --> 02:11:10.760] that you can customize it and specialize it [02:11:10.760 --> 02:11:14.320] to recover smaller effects that other people have found. [02:11:14.320 --> 02:11:17.280] So you can, from the tree of thought paper, [02:11:17.280 --> 02:11:20.000] recover something like the backspace token model, [02:11:20.000 --> 02:11:23.320] or recover a skeleton of thought, chain of thought, [02:11:23.320 --> 02:11:24.600] whatever of thought. [02:11:24.600 --> 02:11:25.520] I don't care. [02:11:25.520 --> 02:11:26.760] I can't keep track anymore. [02:11:26.760 --> 02:11:28.760] Anyway, so I caught up with him at his poster session. [02:11:28.760 --> 02:11:30.000] And here's a bit of our chat. [02:11:30.000 --> 02:11:30.680] Yeah, you can hold it up. [02:11:30.680 --> 02:11:31.180] All right. [02:11:31.180 --> 02:11:32.080] Go ahead. [02:11:32.080 --> 02:11:34.760] So the TL;DR of this paper is very simple. [02:11:34.760 --> 02:11:36.320] Large-language models and search, [02:11:36.320 --> 02:11:38.440] they complement each other. [02:11:38.440 --> 02:11:40.560] So what's wrong with just using large-language models [02:11:40.560 --> 02:11:42.840] without search? [02:11:42.840 --> 02:11:45.920] Is everyone familiar with chain of thought? [02:11:45.920 --> 02:11:48.960] So suppose you're trying to solve this game of 24, [02:11:48.960 --> 02:11:53.640] where given four numbers, you try to combine them to get 24. [02:11:53.640 --> 02:11:57.400] So you can give GPT-4 this task instruction. [02:11:57.400 --> 02:12:00.640] You give it a couple of CLT examples. [02:12:00.640 --> 02:12:02.060] But the performance is really low. [02:12:02.060 --> 02:12:03.280] It's 4%. [02:12:03.280 --> 02:12:04.720] Why is it so hard? [02:12:04.720 --> 02:12:11.400] So that's because this problem intrinsically needs exploration. [02:12:11.400 --> 02:12:14.280] So let's take a look at the initial example. [02:12:14.280 --> 02:12:18.640] So the model is making local token decisions. [02:12:18.640 --> 02:12:20.280] It first generates 10. [02:12:20.280 --> 02:12:22.280] Then it generates times. [02:12:22.280 --> 02:12:24.160] Then it generates 6. [02:12:24.160 --> 02:12:27.080] But it's very hard to decide those initial tokens, [02:12:27.080 --> 02:12:29.040] even for humans. [02:12:29.040 --> 02:12:31.240] You don't really know whether the first token should [02:12:31.240 --> 02:12:33.400] be 10, or 5, or 6, or are they equally good. [02:12:33.400 --> 02:12:35.520] That's really hard to decide. [02:12:35.520 --> 02:12:38.240] But what's worse is once you decide [02:12:38.240 --> 02:12:40.320] the wrong tokens at the beginning, [02:12:40.320 --> 02:12:42.400] the task is already failed. [02:12:42.400 --> 02:12:47.240] So in this particular example, if you generate 10 and times, [02:12:47.240 --> 02:12:48.880] this task is already failed. [02:12:48.880 --> 02:12:51.840] Because no matter what times 10, you [02:12:51.840 --> 02:12:56.240] cannot get three numbers remaining to reach 24. 
[02:12:56.240 --> 02:13:02.840] So the intuition is that autoregressive inference [02:13:02.840 --> 02:13:06.760] is like you're keeping making those local token decisions [02:13:06.760 --> 02:13:09.920] one by one, left to right, without look ahead, [02:13:09.920 --> 02:13:11.440] without backtrack. [02:13:11.440 --> 02:13:13.960] And it's not very robust when you [02:13:13.960 --> 02:13:16.240] don't have good local signals to guide [02:13:16.240 --> 02:13:18.280] through those kind of process. [02:13:18.280 --> 02:13:20.880] So another analogy would be suppose [02:13:20.880 --> 02:13:25.920] you're training a robot that's trying to navigate Macy's. [02:13:25.920 --> 02:13:29.680] If you only train them on successful trajectories, [02:13:29.680 --> 02:13:34.480] and you only train them to predict the next move, [02:13:34.480 --> 02:13:37.040] and you do this local imitation, and you put them [02:13:37.040 --> 02:13:40.400] in a new maze that requires exploration, [02:13:40.400 --> 02:13:43.720] then it probably won't solve the new maze. [02:13:43.720 --> 02:13:47.040] So obviously, some kind of search is needed. [02:13:47.040 --> 02:13:50.280] But why this is a 2023 work, given [02:13:50.280 --> 02:13:54.680] that search has been around since 1940s, 1950s? [02:13:54.680 --> 02:13:59.160] That's because classical search problems, like chess, [02:13:59.160 --> 02:14:02.320] they have a small fixed set of next moves, [02:14:02.320 --> 02:14:04.480] what we call the search space. [02:14:04.480 --> 02:14:07.560] That makes it easy to define, to design, [02:14:07.560 --> 02:14:11.880] or to learn search heuristics to guide the search. [02:14:11.880 --> 02:14:15.360] But here, for this kind of open-ended reasoning, [02:14:15.360 --> 02:14:17.960] the next move can be anything. [02:14:17.960 --> 02:14:19.640] It could be a token. [02:14:19.640 --> 02:14:21.400] It could be a sentence. [02:14:21.400 --> 02:14:23.560] It could be a paragraph. [02:14:23.560 --> 02:14:27.240] And it's impossible to enumerate this huge space, [02:14:27.240 --> 02:14:29.480] or to design evaluations. [02:14:29.480 --> 02:14:33.640] So the key point here is you want [02:14:33.640 --> 02:14:37.840] to really define what is the search space first. [02:14:37.840 --> 02:14:39.460] You can consider two extremes first. [02:14:39.460 --> 02:14:44.800] So on one extreme, you can define a thought [02:14:44.800 --> 02:14:46.480] as the next token. [02:14:46.480 --> 02:14:48.480] Then you will be searching in a tree of tokens, [02:14:48.480 --> 02:14:50.480] something like BeamSearch. [02:14:50.480 --> 02:14:55.520] Then the problem is it's very easy to generate tokens, right? [02:14:55.520 --> 02:14:57.480] But it's very hard to evaluate tokens. [02:14:57.480 --> 02:15:00.600] You don't really know whether 10 is good, or 13 is good, [02:15:00.600 --> 02:15:02.760] or whatever. [02:15:02.760 --> 02:15:08.640] On the other extreme, you can define thought [02:15:08.640 --> 02:15:11.280] as the whole reasoning. [02:15:11.280 --> 02:15:14.560] Then it's very easy to evaluate the thought. [02:15:14.560 --> 02:15:18.280] You just look at if the final number is 24 or not. [02:15:18.280 --> 02:15:20.800] But in this bandit, it will be very hard [02:15:20.800 --> 02:15:23.040] to generate a good thought. [02:15:23.040 --> 02:15:25.760] Otherwise, the task is solved already. 
[02:15:25.760 --> 02:15:28.280] So in this case, it seems like the right balance [02:15:28.280 --> 02:15:32.060] is you define each of the intermediate steps of thought. [02:15:32.060 --> 02:15:35.440] So you can do something like you tell language model, [02:15:35.440 --> 02:15:36.720] here are some numbers. [02:15:36.720 --> 02:15:40.080] Come up with some different ways to combine two of the numbers. [02:15:40.080 --> 02:15:41.880] You can generate a bunch of thoughts. [02:15:41.880 --> 02:15:44.060] And for each of them, you can do something like this. [02:15:44.060 --> 02:15:45.440] You can say, try a few runs. [02:15:45.440 --> 02:15:47.080] Can you reach 24? [02:15:47.080 --> 02:15:49.840] If not, try to decide a value based on that. [02:15:49.840 --> 02:15:51.520] So we're seeing three trials. [02:15:51.520 --> 02:15:54.280] If you can already reach 24, then this thought [02:15:54.280 --> 02:15:56.200] has very high value. [02:15:56.200 --> 02:15:59.280] If it couldn't reach 24, but maybe it could reach maybe 25 [02:15:59.280 --> 02:16:00.880] or 26, maybe OK. [02:16:00.880 --> 02:16:02.600] Maybe a median value is given. [02:16:02.600 --> 02:16:04.800] But if this is something like 1, 2, 3, [02:16:04.800 --> 02:16:07.960] and then you can only reach 6 or 4, [02:16:07.960 --> 02:16:10.280] then maybe a low value is given. [02:16:10.280 --> 02:16:12.640] So this value is not perfect. [02:16:12.640 --> 02:16:15.320] And it does not need to be perfect, [02:16:15.320 --> 02:16:16.960] just like any search heuristics. [02:16:16.960 --> 02:16:20.200] It just needs to bias the search towards promising directions. [02:16:20.200 --> 02:16:24.640] So the point is, once you define this search space, [02:16:24.640 --> 02:16:27.680] you can generate and evaluate next moves [02:16:27.680 --> 02:16:29.240] using large language models. [02:16:29.240 --> 02:16:31.120] And then you can systematically maintain them [02:16:31.120 --> 02:16:32.600] using the tree search algorithm. [02:16:32.600 --> 02:16:34.640] And we show across diverse tasks, [02:16:34.640 --> 02:16:38.200] this significantly improves the task performances. [02:16:38.200 --> 02:16:39.360] And it's very easy to use. [02:16:39.360 --> 02:16:41.400] You don't need to train any new models. [02:16:41.400 --> 02:16:43.480] Everything is done with GPT-4. [02:16:43.480 --> 02:16:44.720] Pretty elegant. [02:16:44.720 --> 02:16:48.480] So I like that in comparison to beam search. [02:16:48.480 --> 02:16:51.600] This is like a level of abstraction [02:16:51.600 --> 02:16:56.240] above that with the atomic unit being a thought. [02:16:56.240 --> 02:16:59.240] A thought here, you illustrated it being an equation. [02:16:59.240 --> 02:17:05.040] But here you have an equation, a clue word, the examples. [02:17:05.040 --> 02:17:07.000] Do you have a planning stage in order [02:17:07.000 --> 02:17:09.640] to plan out the thought steps? [02:17:09.640 --> 02:17:11.360] Here you have thought steps of 3, [02:17:11.360 --> 02:17:13.960] thought steps of 5 to 10, thought steps of 1. [02:17:13.960 --> 02:17:16.400] Usually, when people design agents, [02:17:16.400 --> 02:17:18.600] they'll have a planner. [02:17:18.600 --> 02:17:21.040] But I don't see a planner here. [02:17:21.040 --> 02:17:22.280] -That's a great question. [02:17:22.280 --> 02:17:26.760] And you will notice here is, for those two games, [02:17:26.760 --> 02:17:28.960] the search steps are kind of homogeneous. 
[02:17:28.960 --> 02:17:30.540] Because every step, you're just trying [02:17:30.540 --> 02:17:32.360] to come up with a new equation. [02:17:32.360 --> 02:17:34.560] Or you're trying to come up with a new clue. [02:17:34.560 --> 02:17:36.640] So in this case, you don't really need planning. [02:17:36.640 --> 02:17:39.760] You can just use one generation prompt, one evaluation prompt, [02:17:39.760 --> 02:17:41.760] and use that across different steps. [02:17:41.760 --> 02:17:43.560] But for something more complicated, [02:17:43.560 --> 02:17:46.680] where for each search step, you might do different things, [02:17:46.680 --> 02:17:49.280] then you probably need to plan ahead and maybe design [02:17:49.280 --> 02:17:51.760] different prompts for different kinds of generation [02:17:51.760 --> 02:17:53.360] and different kinds of evaluation. [02:17:53.360 --> 02:17:54.720] -Got it. [02:17:54.720 --> 02:17:57.360] So do you also see this being able to be [02:17:57.360 --> 02:17:59.180] combined with self-consistency? [02:17:59.180 --> 02:18:03.880] Because in a way, your judge is a self-consistent-- [02:18:03.880 --> 02:18:05.360] -That's a great idea. [02:18:05.360 --> 02:18:06.920] And we did that. [02:18:06.920 --> 02:18:12.320] The point here is, in this creative writing task, [02:18:12.320 --> 02:18:18.440] what we do for evaluation is, here's a task instruction. [02:18:18.440 --> 02:18:21.080] Here are some of the plans. [02:18:21.080 --> 02:18:23.000] Think step-by-step what is the best plan [02:18:23.000 --> 02:18:25.360] and come up with an idea. [02:18:25.360 --> 02:18:31.040] So if you just do this one time, you'll just get one vote. [02:18:31.040 --> 02:18:32.760] It's very noisy. [02:18:32.760 --> 02:18:34.760] So you can apply something like self-consistency. [02:18:34.760 --> 02:18:37.760] You can ID, do like 10 different votings, [02:18:37.760 --> 02:18:39.160] or 100 different votings. [02:18:39.160 --> 02:18:42.840] And then the evaluation will become more faithful. [02:18:42.840 --> 02:18:45.000] And that's kind of hyperparameter you can choose. [02:18:45.000 --> 02:18:46.920] It's like, if you want better performance, [02:18:46.920 --> 02:18:48.920] you can spend more money and try to do that more. [02:18:48.920 --> 02:18:50.840] -It's like a post-generation layer? [02:18:50.840 --> 02:18:53.680] -It's like a stepwise democracy, I guess. [02:18:53.680 --> 02:18:57.800] -OK, so one more question about, just in general, Princeton NLP. [02:18:57.800 --> 02:18:58.960] How is it organized? [02:18:58.960 --> 02:19:03.960] What should people know about the Princeton program? [02:19:03.960 --> 02:19:06.080] Because I feel like you guys are very productive. [02:19:06.080 --> 02:19:08.760] And how are you so productive? [02:19:08.760 --> 02:19:10.840] What's the back story to your thoughts, maybe? [02:19:10.840 --> 02:19:12.920] -I think one thing that's good about Princeton [02:19:12.920 --> 02:19:17.760] is it's a kind of small school. [02:19:17.760 --> 02:19:18.480] -I've been there. [02:19:18.480 --> 02:19:19.320] It's not that small. [02:19:19.320 --> 02:19:22.320] -I mean, compared to Harvard or MIT. [02:19:22.320 --> 02:19:25.240] And you have a lot of interdisciplinary collaborations. [02:19:25.240 --> 02:19:28.520] So I did this with cognitive science professors. [02:19:28.520 --> 02:19:31.520] I think this kind of idea across different fields [02:19:31.520 --> 02:19:32.360] is very important. 
[02:19:32.360 --> 02:19:36.800] So usually in NLP, we don't consider tasks like that. [02:19:36.800 --> 02:19:38.320] That's classical search, right? [02:19:38.320 --> 02:19:40.880] So I think it's very useful to combine ideas [02:19:40.880 --> 02:19:41.800] from different fields. [02:19:41.800 --> 02:19:46.760] And that could be a way to come up with new ideas. [02:19:46.760 --> 02:19:50.240] -Are a lot of people asking you about Q* stuff? [02:19:50.240 --> 02:19:51.040] -No. [02:19:51.040 --> 02:19:51.640] -No comments? [02:19:51.640 --> 02:19:52.640] -No comments. [02:19:52.640 --> 02:19:53.880] -OK, well, thank you very much. [02:19:53.880 --> 02:19:55.200] This is a great paper. [02:19:55.200 --> 02:19:57.760] -So perhaps one paper that made a bigger [02:19:57.760 --> 02:20:00.360] splash than Tree of Thought earlier this year [02:20:00.360 --> 02:20:03.160] was Toolformer, where we started really [02:20:03.160 --> 02:20:05.640] considering the myriad ways [02:20:05.640 --> 02:20:07.800] that we can train language models to use tools. [02:20:07.800 --> 02:20:09.560] So here's the Toolformer oral. [02:20:09.560 --> 02:20:10.240] -Hi, everyone. [02:20:10.240 --> 02:20:11.080] My name is Jane. [02:20:11.080 --> 02:20:13.120] I'm a researcher from FAIR Labs at Meta. [02:20:13.120 --> 02:20:15.400] And today, I'm super excited to be presenting to you [02:20:15.400 --> 02:20:19.800] Toolformer, Language Models Can Teach Themselves to Use Tools. [02:20:19.800 --> 02:20:22.080] And the reason we might want language models, [02:20:22.080 --> 02:20:24.560] like ChatGPT, to have access to external tools [02:20:24.560 --> 02:20:27.440] is exemplified by these three queries. [02:20:27.440 --> 02:20:28.960] In the first two cases, I've asked [02:20:28.960 --> 02:20:31.240] who is the current president and what day of the week [02:20:31.240 --> 02:20:32.040] is it today? [02:20:32.040 --> 02:20:33.960] And ChatGPT basically says it doesn't [02:20:33.960 --> 02:20:36.720] have real-time data or access to current time or date [02:20:36.720 --> 02:20:38.080] information. [02:20:38.080 --> 02:20:39.760] And in the final query, I've asked [02:20:39.760 --> 02:20:41.960] to do a simple set of computations. [02:20:41.960 --> 02:20:44.600] But ChatGPT, unfortunately, hallucinates an answer [02:20:44.600 --> 02:20:47.400] that's about 300 off from the real answer. [02:20:47.400 --> 02:20:49.120] And what we really could have used here [02:20:49.120 --> 02:20:51.080] is access to external tools. [02:20:51.080 --> 02:20:54.720] For example, a QA system that has up-to-date information, [02:20:54.720 --> 02:20:57.240] a calendar tool, which has the time or date, [02:20:57.240 --> 02:20:59.880] and a calculator tool, which is designed specifically [02:20:59.880 --> 02:21:03.520] to do these simple computations perfectly. [02:21:03.520 --> 02:21:06.360] And so for Toolformer, we have five tools at our disposal. [02:21:06.360 --> 02:21:09.080] We have a QA system with up-to-date information. [02:21:09.080 --> 02:21:10.760] We have a Wikipedia search tool, which [02:21:10.760 --> 02:21:12.440] is able to search Wikipedia. [02:21:12.440 --> 02:21:15.080] We have a calculator tool, a calendar tool, [02:21:15.080 --> 02:21:17.840] which has the current day of the week and the date. [02:21:17.840 --> 02:21:20.720] And we have a translation tool, which takes in text [02:21:20.720 --> 02:21:23.280] and translates it into English.
[02:21:23.280 --> 02:21:24.800] And so in choosing these five tools, [02:21:24.800 --> 02:21:27.680] we really wanted a set of tools that is not only diverse, [02:21:27.680 --> 02:21:32.840] but is also going to be likely useful to the language model. [02:21:32.840 --> 02:21:34.800] And what we want to train a model to learn [02:21:34.800 --> 02:21:37.440] is not only which of these five tools to use, [02:21:37.440 --> 02:21:39.280] but when to use that particular tool [02:21:39.280 --> 02:21:41.560] and how to use that tool all on its own [02:21:41.560 --> 02:21:43.480] without human annotation. [02:21:43.480 --> 02:21:46.040] And the way we do this is by taking natural language [02:21:46.040 --> 02:21:48.920] text, like Pittsburgh is known as the steel city, [02:21:48.920 --> 02:21:52.880] and augmenting that text with API or tool calls. [02:21:52.880 --> 02:21:55.400] So for example here, a useful API call [02:21:55.400 --> 02:21:57.680] would be to the QA system with the question, [02:21:57.680 --> 02:21:59.960] what other name is Pittsburgh known by? [02:21:59.960 --> 02:22:02.480] And this is useful because it's useful in anticipating [02:22:02.480 --> 02:22:05.840] the remainder of the text, which is the steel city. [02:22:05.840 --> 02:22:08.960] And we represent an API or tool call with natural language. [02:22:08.960 --> 02:22:12.320] We do square brackets followed by the tool name. [02:22:12.320 --> 02:22:13.680] And then in round parentheses, we [02:22:13.680 --> 02:22:15.680] have the input to that tool followed [02:22:15.680 --> 02:22:18.640] by a right arrow, which is followed [02:22:18.640 --> 02:22:22.640] by the output of the tool with that query. [02:22:22.640 --> 02:22:24.680] And with that, the steps to creating ToolFormer [02:22:24.680 --> 02:22:25.640] is pretty simple. [02:22:25.640 --> 02:22:28.000] In the first step, we want to create a new training [02:22:28.000 --> 02:22:30.320] data set augmented with these API calls [02:22:30.320 --> 02:22:32.720] that I just showed you on the previous slide. [02:22:32.720 --> 02:22:34.840] And in the second step, we want to fine tune [02:22:34.840 --> 02:22:38.160] GPT-J, our base model, on this new data set. [02:22:38.160 --> 02:22:39.940] And this fine-tuned model is the model [02:22:39.940 --> 02:22:43.400] that we refer to as ToolFormer. [02:22:43.400 --> 02:22:45.160] Now to create that training data set, [02:22:45.160 --> 02:22:46.760] we have three simple steps, which I'll [02:22:46.760 --> 02:22:48.440] get into in just a second. [02:22:48.440 --> 02:22:51.600] But first, we want to start out with a standard language [02:22:51.600 --> 02:22:53.840] modeling data set, like CCNET. [02:22:53.840 --> 02:22:55.360] And the reason we want to start here [02:22:55.360 --> 02:22:56.960] is because we don't want to disrupt [02:22:56.960 --> 02:22:59.140] any of the core language modeling capabilities [02:22:59.140 --> 02:23:01.160] that the model may already have. [02:23:01.160 --> 02:23:03.200] And so in using a data set or something [02:23:03.200 --> 02:23:04.960] similar to what it's seen before, [02:23:04.960 --> 02:23:08.480] we minimize this risk as much as possible. [02:23:08.480 --> 02:23:10.200] OK, so let's go into the first step, which [02:23:10.200 --> 02:23:11.880] is to generate API calls. [02:23:11.880 --> 02:23:14.640] And to do so, we show the model a simple prompt. [02:23:14.640 --> 02:23:17.800] We say, your task is to add calls to a question answering [02:23:17.800 --> 02:23:19.800] API to a piece of text. 
[02:23:19.800 --> 02:23:21.960] The question should help you get information required [02:23:21.960 --> 02:23:23.640] to complete the text. [02:23:23.640 --> 02:23:27.000] And then we explain the format of the API call that we want, [02:23:27.000 --> 02:23:28.960] and we show it a couple of examples. [02:23:28.960 --> 02:23:30.560] I only have one example here, but we [02:23:30.560 --> 02:23:33.240] would put as many examples as can fit into the context [02:23:33.240 --> 02:23:34.720] window. [02:23:34.720 --> 02:23:36.560] And then we show the input that we actually [02:23:36.560 --> 02:23:39.600] want to do inference on, and we let the model generate. [02:23:39.600 --> 02:23:43.000] And here, I'm only showing you the question answering API [02:23:43.000 --> 02:23:44.560] prompt, but you can imagine that we [02:23:44.560 --> 02:23:48.680] do a very similar thing for the rest of the four tools. [02:23:48.680 --> 02:23:53.000] OK, so let's look at a couple of generated API call examples. [02:23:53.000 --> 02:23:56.000] For the input, Pittsburgh is known as the Steel City. [02:23:56.000 --> 02:23:57.840] And here, the model has generated [02:23:57.840 --> 02:24:00.840] in which state is Pittsburgh, what other name is Pittsburgh [02:24:00.840 --> 02:24:04.480] known by, and what is the second city in Pennsylvania. [02:24:04.480 --> 02:24:05.880] And so from these generations, you [02:24:05.880 --> 02:24:09.080] can see that we get a mix of relevant API calls, [02:24:09.080 --> 02:24:11.340] non-relevant API calls, and also some [02:24:11.340 --> 02:24:15.040] that don't make a lot of sense, like the last one. [02:24:15.040 --> 02:24:16.920] And now for the second step, where we actually [02:24:16.920 --> 02:24:19.160] try to execute those API calls. [02:24:19.160 --> 02:24:21.960] So what we do here is we take that natural language string, [02:24:21.960 --> 02:24:24.180] we parse it for the input parameters, [02:24:24.180 --> 02:24:25.960] we send it to the relevant tool, and we [02:24:25.960 --> 02:24:29.000] get an output from the tool. [02:24:29.000 --> 02:24:31.160] Now using those outputs, we want to put them back [02:24:31.160 --> 02:24:33.120] into the embedded API call, and we [02:24:33.120 --> 02:24:36.200] indicate this with a right arrow followed by the output. [02:24:36.200 --> 02:24:38.840] And this is also the step where we would filter out [02:24:38.840 --> 02:24:41.360] generated API calls that are ill-formatted [02:24:41.360 --> 02:24:45.760] or don't actually return a result from the tool. [02:24:45.760 --> 02:24:47.640] And additionally, we also want to filter out [02:24:47.640 --> 02:24:50.520] API calls that aren't actually useful to the model. [02:24:50.520 --> 02:24:52.880] So the way we want to think about usefulness, [02:24:52.880 --> 02:24:55.840] if it's useful for anticipating the remainder of the text, [02:24:55.840 --> 02:24:58.280] as I showed you earlier. [02:24:58.280 --> 02:24:59.920] And the way we quantify usefulness [02:24:59.920 --> 02:25:01.800] is through model-based perplexity. [02:25:01.800 --> 02:25:04.560] And perplexity is basically the negative log likelihood [02:25:04.560 --> 02:25:08.120] of the remainder of the text given the prefix of the text. [02:25:08.120 --> 02:25:12.760] So basically, you want the lowest perplexity possible [02:25:12.760 --> 02:25:15.040] because you want the model to be least perplexed [02:25:15.040 --> 02:25:17.120] about what it's about to see. 
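Before the perplexity-based filtering described next, here is a sketch of the first two steps as presented: prompt the model to annotate text with API calls, then parse and execute them. The prompt wording paraphrases the talk, and the regex plus the `llm_complete` and `tools` arguments are hypothetical stand-ins; none of this is the paper's released code.

```python
import re

# Few-shot prompt for the QA tool, paraphrasing the instruction quoted above.
QA_PROMPT = """Your task is to add calls to a question answering API to a piece of text.
The questions should help you get information required to complete the text.
You can call the API by writing "[QA(question)]".
Example: Pittsburgh is known as [QA(What other name is Pittsburgh known by?)] the Steel City.
Input: {text}
Output:"""

CALL_PATTERN = re.compile(r"\[(\w+)\((.*?)\)\]")

def generate_api_calls(text, llm_complete, n_samples=5):
    """Step 1: sample annotated copies of the text (most calls get filtered later)."""
    prompt = QA_PROMPT.format(text=text)
    return [llm_complete(prompt) for _ in range(n_samples)]

def execute_api_calls(annotated_text, tools):
    """Step 2: parse each [Tool(input)] call, run the tool, splice the result back in."""
    def run(match):
        tool_name, query = match.group(1), match.group(2)
        if tool_name not in tools:
            return match.group(0)   # ill-formatted or unknown tool: leave the text untouched
        result = tools[tool_name](query)
        return f"[{tool_name}({query}) -> {result}]"
    return CALL_PATTERN.sub(run, annotated_text)
```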
[02:25:17.120 --> 02:25:21.320] So here, we evaluate perplexity under three different settings. [02:25:21.320 --> 02:25:23.960] The first setting is where we don't have any API call. [02:25:23.960 --> 02:25:28.520] So here, the prefix would just be Pittsburgh is known as. [02:25:28.520 --> 02:25:31.600] In the second setting, we have the non-executed API call, [02:25:31.600 --> 02:25:33.680] where we have the API call, but we're not actually [02:25:33.680 --> 02:25:37.040] going to put the result from the tool yet. [02:25:37.040 --> 02:25:39.840] And then finally, we have the full executed API call, [02:25:39.840 --> 02:25:44.000] where we have the API call and its corresponding output. [02:25:44.000 --> 02:25:46.120] And intuitively, what we want here [02:25:46.120 --> 02:25:48.280] is for the perplexity for setting C [02:25:48.280 --> 02:25:50.800] to be much lower than either A or B [02:25:50.800 --> 02:25:54.080] because not only do we want the generated API call to be [02:25:54.080 --> 02:25:57.160] useful, but we also want the results from the tool [02:25:57.160 --> 02:25:59.360] to be really useful. [02:25:59.360 --> 02:26:02.800] So this is exactly how we evaluate usefulness. [02:26:02.800 --> 02:26:04.760] It's the minimum of the perplexity [02:26:04.760 --> 02:26:09.200] of either under A or B minus the perplexity of C. [02:26:09.200 --> 02:26:12.160] So we want that difference to be as large as possible. [02:26:12.160 --> 02:26:15.280] And here, we have a pretty sizable usefulness score [02:26:15.280 --> 02:26:18.120] of 1.3, which is pretty good. [02:26:18.120 --> 02:26:19.580] But to give you more context, here's [02:26:19.580 --> 02:26:21.520] another example from the calendar tool. [02:26:21.520 --> 02:26:24.520] It says the WL will be open on Friday. [02:26:24.520 --> 02:26:26.320] And the calendar tool tells us that today [02:26:26.320 --> 02:26:27.880] is Thursday, March 9. [02:26:27.880 --> 02:26:29.840] And from this, we can infer that Friday [02:26:29.840 --> 02:26:31.640] is going to be March 10. [02:26:31.640 --> 02:26:36.600] So this gets a high usefulness score of 2.11. [02:26:36.600 --> 02:26:38.960] On the other hand, we have this example from the calculator [02:26:38.960 --> 02:26:39.560] tool. [02:26:39.560 --> 02:26:43.400] The model has seen these two numbers, 85 patients and 23%, [02:26:43.400 --> 02:26:45.800] and it thinks maybe the ratio is going to be useful. [02:26:45.800 --> 02:26:47.440] But unfortunately, that's not the case, [02:26:47.440 --> 02:26:51.360] and it gets a low usefulness score of negative 0.02. [02:26:51.360 --> 02:26:56.480] So this would likely be filtered out in our final third step. [02:26:56.480 --> 02:26:59.000] Now, here I'm showing you the number of examples that remain [02:26:59.000 --> 02:27:00.480] after this filtering process. [02:27:00.480 --> 02:27:02.040] For two different kinds of thresholds, [02:27:02.040 --> 02:27:06.120] we have in light blue 0.5, and in dark blue, we have 1.0. [02:27:06.120 --> 02:27:08.360] And obviously, you're going to get a lot more examples [02:27:08.360 --> 02:27:12.120] left over if you use a less stringent threshold, 0.5. [02:27:12.120 --> 02:27:13.800] But the other thing you can see here [02:27:13.800 --> 02:27:15.600] is that we have the most number of examples [02:27:15.600 --> 02:27:18.360] from the Wikipedia search tool, whereas for calculator [02:27:18.360 --> 02:27:22.240] and machine translation, we have the fewest number of examples. 
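And here is the filtering rule in miniature: keep an API call only if conditioning on the call and its result lowers the loss on the continuation by at least a threshold, relative to the better of the other two settings. In this sketch, `nll_of_continuation` is a hypothetical helper returning the model's negative log-likelihood of the continuation given a prefix, and the thresholds 0.5 and 1.0 are the ones mentioned above.

```python
def usefulness(nll_of_continuation, continuation,
               prefix_no_call, prefix_call_only, prefix_call_with_result):
    l_a = nll_of_continuation(prefix_no_call, continuation)           # setting A: no API call
    l_b = nll_of_continuation(prefix_call_only, continuation)         # setting B: call without result
    l_c = nll_of_continuation(prefix_call_with_result, continuation)  # setting C: call with result
    # The executed call must beat the best alternative by a clear margin.
    return min(l_a, l_b) - l_c

def keep_call(score, threshold=1.0):
    """Filter step: the talk discusses both a lenient (0.5) and a stringent (1.0) threshold."""
    return score >= threshold
```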
[02:27:22.240 --> 02:27:24.600] And now what we do here is we cap the number of examples [02:27:24.600 --> 02:27:27.760] per tool at 25,000, and we put it all together [02:27:27.760 --> 02:27:30.480] in one big data set. [02:27:30.480 --> 02:27:33.840] And with that data set, we fine tune our base model, GPT-J. [02:27:33.840 --> 02:27:39.040] And this fine tune model is what we refer to as tool former. [02:27:39.040 --> 02:27:40.760] Now, to evaluate tool former, we want [02:27:40.760 --> 02:27:42.800] to evaluate on a range of tasks where [02:27:42.800 --> 02:27:45.640] we think at least one of the tool is going to be useful. [02:27:45.640 --> 02:27:48.240] So we have fact completion and question answering. [02:27:48.240 --> 02:27:51.200] We also have math computations and multilingual questions [02:27:51.200 --> 02:27:53.320] where the context is given in English, [02:27:53.320 --> 02:27:56.440] but the question can be in a different language. [02:27:56.440 --> 02:27:58.240] And we also have temporal questions, [02:27:58.240 --> 02:28:00.520] like how many days is it until Christmas, [02:28:00.520 --> 02:28:02.800] where you need to know the current time or date in order [02:28:02.800 --> 02:28:05.720] to answer the question. [02:28:05.720 --> 02:28:08.080] Now here are the results for those five tasks. [02:28:08.080 --> 02:28:09.360] We have three different models. [02:28:09.360 --> 02:28:12.640] We have GPT-J, which is the base model, tool former, [02:28:12.640 --> 02:28:17.000] and GPT-3, which is a 175 billion parameter model. [02:28:17.000 --> 02:28:19.560] And what you can see here is that in almost all cases, [02:28:19.560 --> 02:28:21.920] tool former is outperforming GPT-J, [02:28:21.920 --> 02:28:23.960] but it's also outperforming GPT-3, [02:28:23.960 --> 02:28:27.200] even though it's about 30 times smaller than GPT-3. [02:28:27.200 --> 02:28:30.260] And an exception to this is the question answering task, [02:28:30.260 --> 02:28:32.720] where we actually disabled the QA system. [02:28:32.720 --> 02:28:34.560] And this is because there's a lot of overlap [02:28:34.560 --> 02:28:37.600] in the training set of the QA system and our evaluation tasks. [02:28:37.600 --> 02:28:39.800] So we thought this would be too much of an advantage [02:28:39.800 --> 02:28:41.920] if we enabled that tool. [02:28:41.920 --> 02:28:44.760] The second anomaly is the multilingual task, [02:28:44.760 --> 02:28:47.160] where we don't see a lot of benefit from the translation [02:28:47.160 --> 02:28:47.660] tool. [02:28:47.660 --> 02:28:50.360] And we think this is likely because GPT-J has already [02:28:50.360 --> 02:28:52.520] seen a lot of multilingual text and isn't [02:28:52.520 --> 02:28:56.320] getting a whole lot of benefit from actually using that tool. [02:28:56.320 --> 02:28:58.080] But regardless, we see that tool former [02:28:58.080 --> 02:29:03.480] is either on par with GPT-J or outperforming GPT-J. [02:29:03.480 --> 02:29:05.560] And the second thing that we want to look at [02:29:05.560 --> 02:29:08.720] is whether or not small models can effectively use tools. [02:29:08.720 --> 02:29:11.360] So in other words, is there a minimum size requirement [02:29:11.360 --> 02:29:15.600] with which models are able to effectively use tools? [02:29:15.600 --> 02:29:18.400] So to investigate this, we applied the same kind [02:29:18.400 --> 02:29:21.480] of pipeline to the family of GPT-2 models. [02:29:21.480 --> 02:29:22.640] So there are four of them. 
[02:29:22.640 --> 02:29:24.480] And in total, we have five different models [02:29:24.480 --> 02:29:27.880] at various sizes, which I'm showing you on the x-axis. [02:29:27.880 --> 02:29:30.920] And on the y-axis, we have model performance. [02:29:30.920 --> 02:29:32.880] And in blue, we have tool former. [02:29:32.880 --> 02:29:34.960] And in red, we have tool former disabled, [02:29:34.960 --> 02:29:39.280] where we use constraint decoding to prevent the usage of tools. [02:29:39.280 --> 02:29:41.280] And as you can see, in the smallest two sizes, [02:29:41.280 --> 02:29:43.960] we don't see any performance difference between tool former [02:29:43.960 --> 02:29:45.720] and tool former disabled, meaning [02:29:45.720 --> 02:29:48.280] that tool former is not able to make use of those five [02:29:48.280 --> 02:29:50.280] tools to its fullest. [02:29:50.280 --> 02:29:53.520] But once we get to 775 million parameters, [02:29:53.520 --> 02:29:55.240] we see a performance gap emerging. [02:29:55.240 --> 02:29:56.640] And this gets bigger and sustained [02:29:56.640 --> 02:29:59.280] for the rest of the sizes. [02:29:59.280 --> 02:30:02.160] And this is a similar thing that we see with the math benchmarks. [02:30:02.160 --> 02:30:04.400] It seems that tool usage is really emerging [02:30:04.400 --> 02:30:07.120] at 775 million parameters. [02:30:07.120 --> 02:30:09.080] For the question answering benchmarks, [02:30:09.080 --> 02:30:10.400] we don't see this as clearly. [02:30:10.400 --> 02:30:12.080] And we think that maybe this is likely [02:30:12.080 --> 02:30:14.960] because the QA system and the Wikipedia search tool [02:30:14.960 --> 02:30:16.920] are easier tools to use. [02:30:16.920 --> 02:30:18.760] And so you don't need a more capable model [02:30:18.760 --> 02:30:22.800] to be able to understand how to use it effectively. [02:30:22.800 --> 02:30:24.520] And finally, we also want to revisit [02:30:24.520 --> 02:30:26.240] the question of whether or not tool former [02:30:26.240 --> 02:30:27.760] is a good language model. [02:30:27.760 --> 02:30:30.040] We originally used a data set CCNet [02:30:30.040 --> 02:30:31.760] because we didn't want to disrupt [02:30:31.760 --> 02:30:34.300] any of the core language modeling capabilities. [02:30:34.300 --> 02:30:35.880] And so now we revisit that question [02:30:35.880 --> 02:30:38.720] by looking at perplexity on a held-out set of Wikitext [02:30:38.720 --> 02:30:40.120] and CCNet. [02:30:40.120 --> 02:30:43.080] And here we have three different models, GPT-J, [02:30:43.080 --> 02:30:46.000] GPT-J further fine-tuned on CCNet, [02:30:46.000 --> 02:30:49.440] and tool former, which is further fine-tuned on CCNet, [02:30:49.440 --> 02:30:52.320] augmented with those API calls. [02:30:52.320 --> 02:30:54.440] And what we find is that the perplexity is pretty much [02:30:54.440 --> 02:30:57.480] on par with the base model and the further fine-tuned one. [02:30:57.480 --> 02:30:59.320] We don't see a whole lot of difference. [02:30:59.320 --> 02:31:01.320] And so we feel pretty encouraged that even [02:31:01.320 --> 02:31:04.280] though this data set may look a bit unnatural with these API [02:31:04.280 --> 02:31:07.000] calls, it doesn't actually harm the core language modeling [02:31:07.000 --> 02:31:10.200] capabilities here. [02:31:10.200 --> 02:31:12.560] So thank you for listening to this talk. [02:31:12.560 --> 02:31:14.560] Please check out our paper at this QR code. 
[02:31:14.560 --> 02:31:16.880] We have a poster in the next poster session. [02:31:16.880 --> 02:31:19.040] We are poster number 332. [02:31:19.040 --> 02:31:21.480] And I will be there with Roberta. [02:31:21.480 --> 02:31:24.520] Please feel free to reach out to any of the co-authors and me. [02:31:24.520 --> 02:31:27.520] I'm happy to take questions now or later. [02:31:27.520 --> 02:31:28.520] Thanks. [02:31:28.520 --> 02:31:30.280] [APPLAUSE] [02:31:30.280 --> 02:31:32.920] When I look at all the relevant papers for AI engineers [02:31:32.920 --> 02:31:35.720] this year, there's the chain-of-thought papers [02:31:35.720 --> 02:31:39.360] and the tool-use papers, two of which we just covered. [02:31:39.360 --> 02:31:41.960] But something that I think incorporates all of them [02:31:41.960 --> 02:31:47.360] and then adds a few ideas that are unique and notable to them [02:31:47.360 --> 02:31:51.040] is the Voyager paper from NVIDIA. [02:31:51.040 --> 02:31:53.640] And even though it was released in the first half of the year, [02:31:53.640 --> 02:31:55.320] people are still talking about it today. [02:31:55.320 --> 02:31:57.880] It's still shaping people's mental perceptions [02:31:57.880 --> 02:32:01.640] of how they want to build their LLM architectures. [02:32:01.640 --> 02:32:05.640] It was somehow not accepted for posters or oral sessions [02:32:05.640 --> 02:32:07.080] at this year's NeurIPS. [02:32:07.080 --> 02:32:09.080] It's kind of a mystery as to why. [02:32:09.080 --> 02:32:10.480] I did chat with Jim. [02:32:10.480 --> 02:32:13.160] And I'm still not really sure what's going on there. [02:32:13.160 --> 02:32:15.600] But it would have been my vote for best paper [02:32:15.600 --> 02:32:18.520] because it's so foundational and established [02:32:18.520 --> 02:32:21.120] such a strong baseline for everyone else to build [02:32:21.120 --> 02:32:22.320] on top of LLMs. [02:32:22.320 --> 02:32:25.640] And anyway, so there were some workshop presentations [02:32:25.640 --> 02:32:27.800] about Voyager with the first author. [02:32:27.800 --> 02:32:28.800] So here it is. [02:32:28.800 --> 02:32:31.760] [APPLAUSE] [02:32:31.760 --> 02:32:33.640] My name is Guanzhi Wang. [02:32:33.640 --> 02:32:36.520] Currently, I'm a third-year PhD student at Caltech. [02:32:36.520 --> 02:32:38.760] I'm also a research intern at NVIDIA. [02:32:38.760 --> 02:32:40.600] I'm very happy to present Voyager, [02:32:40.600 --> 02:32:44.920] an open-ended, embodied agent with large-language models. [02:32:44.920 --> 02:32:47.920] This year, GPT-4 came out, a large-language model [02:32:47.920 --> 02:32:51.280] that's so good at coding and long-horizon planning. [02:32:51.280 --> 02:32:54.520] So we built Voyager, the first large-language model-powered [02:32:54.520 --> 02:32:56.120] lifelong learning agent. [02:32:56.120 --> 02:32:58.420] When we set Voyager loose in Minecraft, [02:32:58.420 --> 02:33:00.960] it is able to play the game for hours on end [02:33:00.960 --> 02:33:03.780] without any human intervention. [02:33:03.780 --> 02:33:06.560] The video here shows snippets from a single episode [02:33:06.560 --> 02:33:07.640] of Voyager. [02:33:07.640 --> 02:33:12.160] So it explores the terrains, mines all kinds of materials, [02:33:12.160 --> 02:33:15.480] fights monsters, crafts hundreds of recipes, [02:33:15.480 --> 02:33:18.920] and unlocks an ever-expanding tree of skills.
[02:33:18.920 --> 02:33:21.680] If you want to use the full power of GPT-4, [02:33:21.680 --> 02:33:25.360] a central question is, how do we stringify things? [02:33:25.360 --> 02:33:27.400] In other words, how do we convert [02:33:27.400 --> 02:33:30.720] this embodied environment with multi-modal observation [02:33:30.720 --> 02:33:33.060] and action space into pure text? [02:33:33.060 --> 02:33:34.560] We need a magic box. [02:33:34.560 --> 02:33:37.320] And thankfully, the enthusiastic Minecraft community [02:33:37.320 --> 02:33:38.560] already built one. [02:33:38.560 --> 02:33:40.920] It's called Mineflayer, a high-level JavaScript [02:33:40.920 --> 02:33:43.040] API that's actively maintained to work [02:33:43.040 --> 02:33:45.000] with every Minecraft version. [02:33:45.000 --> 02:33:46.780] The beauty of Mineflayer is that it [02:33:46.780 --> 02:33:49.160] has access to the game state surrounding [02:33:49.160 --> 02:33:53.860] the agent, like the nearby blocks, animals, and enemies. [02:33:53.860 --> 02:33:57.040] So we effectively have a ground-truth perception module [02:33:57.040 --> 02:33:58.360] as a textual channel. [02:33:58.360 --> 02:34:00.580] Now that we convert everything to text, [02:34:00.580 --> 02:34:04.320] we are ready to construct an agent algorithm on top of GPT-4. [02:34:04.320 --> 02:34:07.200] And on the high level, there are three components. [02:34:07.200 --> 02:34:09.760] First, a coding module that writes JavaScript [02:34:09.760 --> 02:34:11.520] to control the game bot. [02:34:11.520 --> 02:34:15.780] It's the main module that generates executable actions. [02:34:15.780 --> 02:34:19.200] Second, we have a code base to store the correctly written [02:34:19.200 --> 02:34:21.140] code and look it up in the future [02:34:21.140 --> 02:34:23.240] if the agent needs to recall the skill. [02:34:23.240 --> 02:34:25.980] In this way, we don't duplicate coding efforts [02:34:25.980 --> 02:34:30.960] and achieve a form of learning without gradient descent. [02:34:30.960 --> 02:34:33.080] Third, we have a curriculum that proposes [02:34:33.080 --> 02:34:36.760] what to do next, given the agent's current capabilities. [02:34:36.760 --> 02:34:38.880] So when we wire them up together, [02:34:38.880 --> 02:34:41.760] we get a loop that drives the agent indefinitely [02:34:41.760 --> 02:34:44.480] and achieves something like lifelong learning. [02:34:44.480 --> 02:34:47.040] So let's zoom in on the coding module. [02:34:47.040 --> 02:34:50.240] We prompt GPT-4 with documentations and examples [02:34:50.240 --> 02:34:53.480] on how to use a subset of the Mineflayer API. [02:34:53.480 --> 02:34:56.240] Then GPT-4 writes code to take actions [02:34:56.240 --> 02:34:58.560] given the current assigned task. [02:34:58.560 --> 02:35:01.720] And because JavaScript runs in a code interpreter, [02:35:01.720 --> 02:35:04.400] GPT-4 can define new functions on the fly [02:35:04.400 --> 02:35:06.280] and run them interactively. [02:35:06.280 --> 02:35:08.840] But the code that GPT-4 writes isn't always [02:35:08.840 --> 02:35:11.240] right on the first try. [02:35:11.240 --> 02:35:13.440] We develop an iterative prompting mechanism [02:35:13.440 --> 02:35:15.120] to refine the program. [02:35:15.120 --> 02:35:17.920] There are three types of feedback. [02:35:17.920 --> 02:35:21.760] First, the environment feedback, like what new materials did [02:35:21.760 --> 02:35:24.200] you get after taking an action?
[02:35:24.200 --> 02:35:27.400] Second, the execution error from JavaScript interpreter, [02:35:27.400 --> 02:35:31.640] like variable undefined error. [02:35:31.640 --> 02:35:33.760] And we have another GPT-4 that provides [02:35:33.760 --> 02:35:35.320] critiques through self-reflection [02:35:35.320 --> 02:35:37.200] from the agent's own states. [02:35:37.200 --> 02:35:40.080] So these components help the agent refine the program [02:35:40.080 --> 02:35:43.560] effectively. [02:35:43.560 --> 02:35:45.840] I want to show some examples of how the critique [02:35:45.840 --> 02:35:50.840] module provides feedback on the task completion progress. [02:35:50.840 --> 02:35:54.200] In the first example, the task is to craft a spyglass. [02:35:54.200 --> 02:35:56.800] So GPT-4 looks at the agent's inventory [02:35:56.800 --> 02:35:59.200] and decides that it has enough copper, [02:35:59.200 --> 02:36:03.040] but not enough amethyst. [02:36:03.040 --> 02:36:06.160] Second task is to kill three sheep to collect food. [02:36:06.160 --> 02:36:09.000] So each sheep drops one unit of white wool, [02:36:09.000 --> 02:36:11.640] but there are only two units in the inventory. [02:36:11.640 --> 02:36:13.040] So one more sheep to go. [02:36:13.040 --> 02:36:18.800] Last example, killing a zombie drops a unit [02:36:18.800 --> 02:36:21.120] of rotten flesh, which is in the inventory. [02:36:21.120 --> 02:36:24.400] So GPT-4 determines that the task is successful [02:36:24.400 --> 02:36:25.800] and moves on. [02:36:25.800 --> 02:36:27.840] So this critique procedure is repeated [02:36:27.840 --> 02:36:30.520] until the task is deemed successful [02:36:30.520 --> 02:36:31.760] or hits the time limit. [02:36:31.760 --> 02:36:39.000] Now, moving on to the second part. [02:36:39.000 --> 02:36:41.640] Once it implements a skill correctly, [02:36:41.640 --> 02:36:44.200] we save it to our persistent storage. [02:36:44.200 --> 02:36:46.440] So think of it as a skill library [02:36:46.440 --> 02:36:51.120] that's authored purely by GPT-4 through trial and error. [02:36:51.120 --> 02:36:54.040] Then the agent can retrieve the skills from the library [02:36:54.040 --> 02:36:56.400] when facing similar situations in the future. [02:36:56.400 --> 02:36:58.320] So it doesn't need to write them again. [02:36:58.320 --> 02:37:00.480] In this way, Voyager improves itself [02:37:00.480 --> 02:37:03.000] as it experiences more and more in Minecraft. [02:37:03.000 --> 02:37:08.480] Let's dive a bit deeper into how the skill library is [02:37:08.480 --> 02:37:09.800] implemented. [02:37:09.800 --> 02:37:12.320] So this is how we insert a new skill. [02:37:12.320 --> 02:37:15.640] First, we use GPT-3.5 to summarize the program [02:37:15.640 --> 02:37:17.000] into plain English. [02:37:17.000 --> 02:37:19.840] So summarization is very easy and doesn't need GPT-4. [02:37:19.840 --> 02:37:22.480] So we save some money here. [02:37:22.480 --> 02:37:25.800] Then the embedding of the summary becomes a key, [02:37:25.800 --> 02:37:28.040] and the program becomes a value, which [02:37:28.040 --> 02:37:32.000] we insert into a vector database. [02:37:32.000 --> 02:37:35.080] We find it better to embed the description instead [02:37:35.080 --> 02:37:38.480] of the raw program, because it's more semantic [02:37:38.480 --> 02:37:40.760] and improves the retrieval. 
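Here is a small sketch of the skill-library insertion just described. Every name is illustrative: `summarize` stands in for the GPT-3.5 summarization call, `embed` for a text-embedding call, and the store is a plain in-memory list rather than a real vector database.

```python
class SkillLibrary:
    def __init__(self, summarize, embed):
        self.summarize = summarize   # program (JavaScript text) -> plain-English description
        self.embed = embed           # text -> embedding vector (list of floats)
        self.skills = []             # entries of (key_vector, description, program)

    def add_skill(self, program: str):
        """Key the skill by an embedding of its description, not of the raw code."""
        description = self.summarize(program)   # a cheaper model is enough for this step
        key = self.embed(description)
        self.skills.append((key, description, program))
```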
[02:37:40.760 --> 02:37:47.640] Now, when Voyager is faced with a new task-- [02:37:47.640 --> 02:37:50.080] let's say, craft iron pickaxe-- [02:37:50.080 --> 02:37:54.440] we use GPT-3.5 to generate a hint on how to solve the task [02:37:54.440 --> 02:37:58.040] and combine it with the world state as the query context. [02:37:58.040 --> 02:38:03.440] Then we do the embedding and retrieve the top five [02:38:03.440 --> 02:38:05.240] relevant skills from the skill library. [02:38:05.240 --> 02:38:11.200] So Voyager is free to directly use one of the skills as is, [02:38:11.200 --> 02:38:16.240] or interpolate among the five, or rewrite one from scratch. [02:38:16.240 --> 02:38:19.920] In this way, we maximally reuse the old experiences. [02:38:19.920 --> 02:38:22.840] Think of it as an in-context replay buffer [02:38:22.840 --> 02:38:24.920] in the reinforcement learning terminology. [02:38:24.920 --> 02:38:31.560] Now, moving on to the third part. [02:38:31.560 --> 02:38:34.600] We have yet another GPT-4 that proposes [02:38:34.600 --> 02:38:38.160] what task to do, given its own capability at the moment. [02:38:38.160 --> 02:38:40.680] The curriculum has an unsupervised objective, [02:38:40.680 --> 02:38:43.720] which is to maximize the number of novel items [02:38:43.720 --> 02:38:45.520] that the agent obtains. [02:38:45.520 --> 02:38:47.880] There are two key insights here. [02:38:47.880 --> 02:38:51.000] First, it's kind of curiosity-driven exploration, [02:38:51.000 --> 02:38:53.640] or novelty search in prior literature, [02:38:53.640 --> 02:38:55.640] but implemented purely in context. [02:38:55.640 --> 02:38:59.120] Oh, sorry. [02:38:59.120 --> 02:39:01.440] And second, it's [02:39:01.440 --> 02:39:05.000] a curriculum that naturally gets progressively harder over time, [02:39:05.000 --> 02:39:09.400] all without any manual prescription from us. [02:39:09.400 --> 02:39:12.360] So let's go through a working example together. [02:39:12.360 --> 02:39:16.280] The agent finds its hunger bar dropping to 1 out of 20, [02:39:16.280 --> 02:39:18.200] so it needs to find food. [02:39:18.200 --> 02:39:21.240] Now, it senses four entities nearby-- [02:39:21.240 --> 02:39:25.080] a cat, a villager, a pig, and some wheat seed. [02:39:25.080 --> 02:39:27.280] So it starts an inner monologue. [02:39:27.280 --> 02:39:29.240] Do I kill the cat or villager? [02:39:29.240 --> 02:39:30.280] Bad idea. [02:39:30.280 --> 02:39:31.560] How about the wheat seed? [02:39:31.560 --> 02:39:34.840] I can grow a farm, but it's going to take a long time. [02:39:34.840 --> 02:39:37.280] So sorry, piggy, you are the chosen one. [02:39:37.280 --> 02:39:41.600] It checks the inventory and retrieves an old skill [02:39:41.600 --> 02:39:43.800] from the library to craft an iron sword, [02:39:43.800 --> 02:39:47.080] and then starts to learn a new skill called hunt pig. [02:39:47.080 --> 02:39:49.880] Now, we also know that Voyager isn't a vegetarian, [02:39:49.880 --> 02:39:53.400] unfortunately. [02:39:53.400 --> 02:39:55.440] So putting all the pieces together, we [02:39:55.440 --> 02:39:57.160] have an iterative prompting mechanism [02:39:57.160 --> 02:40:01.040] that refines a program via self-debugging, [02:40:01.040 --> 02:40:06.320] a skill library as an in-context replay buffer, [02:40:06.320 --> 02:40:09.680] and an automatic curriculum as in-context curiosity-driven [02:40:09.680 --> 02:40:11.240] exploration.
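Continuing the skill-library sketch above, retrieval for a new task might look like the following: generate a hint with the cheaper model, embed the hint plus the current world state as the query, and take the top five skills by similarity. Cosine similarity and the helper names are assumptions for illustration, not the released implementation.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def retrieve_skills(library, task, world_state, suggest, k=5):
    hint = suggest(task)                               # e.g. a how-to hint for "craft iron pickaxe"
    query = library.embed(hint + "\n" + world_state)   # hint plus world state form the query
    ranked = sorted(library.skills,
                    key=lambda entry: cosine(entry[0], query),
                    reverse=True)
    # The agent may use one of these as is, interpolate among them, or rewrite from scratch.
    return [program for _, _, program in ranked[:k]]
```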
[02:40:11.240 --> 02:40:13.640] This is Voyager's no-gradient architecture, [02:40:13.640 --> 02:40:15.560] where we don't train any new model [02:40:15.560 --> 02:40:17.160] or train any parameters. [02:40:17.160 --> 02:40:20.120] It allows Voyager to self-bootstrap and perform [02:40:20.120 --> 02:40:21.880] lifelong learning in an open-ended world. [02:40:21.880 --> 02:40:27.720] So these are the tasks that Voyager happens [02:40:27.720 --> 02:40:29.480] to pick up along the way. [02:40:29.480 --> 02:40:31.560] We didn't pre-program any of this. [02:40:31.560 --> 02:40:33.440] It's all Voyager's idea. [02:40:33.440 --> 02:40:35.880] The agent is forever curious and forever [02:40:35.880 --> 02:40:37.280] pursuing new adventures. [02:40:37.280 --> 02:40:43.240] We've done a lot of systematic study for Voyager, [02:40:43.240 --> 02:40:45.560] and here is the quantitative learning curve. [02:40:45.560 --> 02:40:48.720] Well, the x-axis is the number of prompting iterations, [02:40:48.720 --> 02:40:51.660] and the y-axis is the number of unique items [02:40:51.660 --> 02:40:53.280] obtained by each agent. [02:40:53.280 --> 02:40:55.880] We compare with three baselines-- [02:40:55.880 --> 02:40:58.880] ReAct, Reflexion, and Auto-GPT. [02:40:58.880 --> 02:41:00.200] All of these are [02:41:00.200 --> 02:41:08.840] no-gradient architectures on top of GPT-4. [02:41:08.840 --> 02:41:11.800] ReAct is a very simple reasoning and acting loop, [02:41:11.800 --> 02:41:15.200] and Reflexion is built on top of ReAct with self-reflection. [02:41:15.200 --> 02:41:17.600] We see that both struggle to make progress [02:41:17.600 --> 02:41:21.420] beyond the basic wooden tools. [02:41:21.420 --> 02:41:24.000] And Auto-GPT is a popular software repo. [02:41:24.000 --> 02:41:27.480] It combines ReAct and a task planner [02:41:27.480 --> 02:41:30.480] that decomposes an objective into sub-goals. [02:41:30.480 --> 02:41:33.720] It makes more progress, but it's very slow. [02:41:33.720 --> 02:41:34.720] And this is Voyager. [02:41:34.720 --> 02:41:37.140] We are able to obtain three times more novel items [02:41:37.140 --> 02:41:39.840] than the prior method and unlock the whole tech tree [02:41:39.840 --> 02:41:46.840] significantly faster from wooden to stone to iron to diamond. [02:41:46.840 --> 02:41:48.480] The blue curve here is an ablation [02:41:48.480 --> 02:41:51.560] without the skill library, which plateaus after a while. [02:41:51.560 --> 02:41:53.120] So basically, the skill library is [02:41:53.120 --> 02:41:56.480] essential for Voyager's lifelong learning capabilities. [02:41:56.480 --> 02:42:02.400] Here are two bird's-eye views of Minecraft maps. [02:42:02.400 --> 02:42:04.720] So these circles are what the prior methods [02:42:04.720 --> 02:42:07.600] explore, given the same prompting iteration budget. [02:42:07.600 --> 02:42:11.320] You can see that they tend to get stuck in local areas. [02:42:11.320 --> 02:42:13.960] Voyager is able to navigate distances two times longer [02:42:13.960 --> 02:42:15.600] compared to prior works. [02:42:15.600 --> 02:42:17.840] It has to visit more diverse terrains in order [02:42:17.840 --> 02:42:20.960] to find more novel items quickly.
[02:42:20.960 --> 02:42:23.820] Finally, one limitation is that Voyager does not currently [02:42:23.820 --> 02:42:27.200] support visual perception, because GPT-4 was text-only [02:42:27.200 --> 02:42:29.440] when we were developing Voyager, but there's [02:42:29.440 --> 02:42:31.880] nothing stopping Voyager from using a multimodal model [02:42:31.880 --> 02:42:34.600] to achieve more impressive tasks. [02:42:34.600 --> 02:42:36.920] And here we demonstrate that, given human feedback, [02:42:36.920 --> 02:42:39.280] Voyager is able to construct complex 3D structures [02:42:39.280 --> 02:42:42.960] in Minecraft, such as a house and a nether portal. [02:42:42.960 --> 02:42:45.600] We basically use the human to replace the critic module [02:42:45.600 --> 02:42:49.000] of Voyager and provide 3D spatial advice. [02:42:49.000 --> 02:42:51.280] So to build very complex structures, [02:42:51.280 --> 02:42:55.440] we definitely need some full-blown multimodal models, [02:42:55.440 --> 02:42:58.280] and I will leave that to future works. [02:42:58.280 --> 02:43:01.480] This is Voyager's website at voyager.minedojo.org. [02:43:01.480 --> 02:43:03.880] We open source everything, including the environment, [02:43:03.880 --> 02:43:08.200] algorithm, prompts, and pre-trained skill libraries. [02:43:08.200 --> 02:43:10.440] Finally, I want to acknowledge all the team members [02:43:10.440 --> 02:43:11.280] of Voyager. [02:43:11.280 --> 02:43:13.560] This work would not be possible without their help. [02:43:13.560 --> 02:43:16.440] So please feel free to reach out if you have any questions. [02:43:16.440 --> 02:43:16.940] Thanks. [02:43:16.940 --> 02:43:19.840] [APPLAUSE] [02:43:19.840 --> 02:43:22.800] I think the last component of agents, [02:43:22.800 --> 02:43:26.120] apart from chain of thought and tool use [02:43:26.120 --> 02:43:28.080] that I wrote up in the Anatomy of Autonomy [02:43:28.080 --> 02:43:32.520] write-up in April, is the need for better planning. [02:43:32.520 --> 02:43:36.880] And I think one of the most interesting or challenging [02:43:36.880 --> 02:43:39.480] pieces, depending on how you look at it, of NeurIPS [02:43:39.480 --> 02:43:43.080] is doing poster diving, where instead [02:43:43.080 --> 02:43:45.000] of going to all the oral sessions, which [02:43:45.000 --> 02:43:48.320] have been curated by track committees and all that, [02:43:48.320 --> 02:43:51.600] you just go and walk the halls and look for posters [02:43:51.600 --> 02:43:52.620] and look for papers. [02:43:52.620 --> 02:43:55.960] And you find papers that are underrated or have been overlooked. [02:43:55.960 --> 02:43:58.680] And in fact, the original "Attention is All You Need" [02:43:58.680 --> 02:44:01.000] Transformers paper was one such paper, [02:44:01.000 --> 02:44:04.080] where it was just a poster-only paper, apparently. [02:44:04.080 --> 02:44:06.400] From walking the halls in the poster sessions, [02:44:06.400 --> 02:44:09.720] my pick for underrated paper was Ida Momennejad [02:44:09.720 --> 02:44:12.520] from Microsoft Research with CogEval. [02:44:12.520 --> 02:44:14.640] Ida was very confident and professorial [02:44:14.640 --> 02:44:17.880] in her presentation, made it engaging, made it a quiz. [02:44:17.880 --> 02:44:19.360] Some parts of the quiz are visual.
[02:44:19.360 --> 02:44:20.840] So if you're listening along and you [02:44:20.840 --> 02:44:22.840] want to solve it alongside us, you [02:44:22.840 --> 02:44:24.920] should probably pull up the show notes [02:44:24.920 --> 02:44:27.720] and check out the graphs that I'm going to paste inside [02:44:27.720 --> 02:44:29.320] of the show notes. [02:44:29.320 --> 02:44:31.320] But otherwise, she just made it very engaging [02:44:31.320 --> 02:44:33.040] for people to follow along. [02:44:33.040 --> 02:44:33.740] I'm not kidding. [02:44:33.740 --> 02:44:35.480] There was a group of 10, 20 of us [02:44:35.480 --> 02:44:37.240] way back in the halls in the poster sessions [02:44:37.240 --> 02:44:39.820] where a lot of people don't really end up going. [02:44:39.820 --> 02:44:41.980] And we were just half an hour while she [02:44:41.980 --> 02:44:45.800] was giving her impromptu lecture about COG eval. [02:44:45.800 --> 02:44:49.360] And I do think that this is notable because it is [02:44:49.360 --> 02:44:52.440] potentially a quantifiable benchmark for reasoning [02:44:52.440 --> 02:44:55.860] and planning capabilities that currently all the language [02:44:55.860 --> 02:44:57.920] models don't do very well. [02:44:57.920 --> 02:44:59.760] And framing it as a graph problem [02:44:59.760 --> 02:45:02.320] helps us generalize to all sorts of reasoning, planning, [02:45:02.320 --> 02:45:04.040] and search situations. [02:45:04.040 --> 02:45:06.340] And I just like that it was really well presented. [02:45:06.340 --> 02:45:07.800] This is obviously a benchmark paper, [02:45:07.800 --> 02:45:09.480] so there's no solutions proposed. [02:45:09.480 --> 02:45:11.920] But she has another paper that she's working on [02:45:11.920 --> 02:45:13.680] that has some of her solutions. [02:45:13.680 --> 02:45:15.720] So LLMs are ubiquitous. [02:45:15.720 --> 02:45:17.980] And a lot of people claim that they can plan [02:45:17.980 --> 02:45:19.940] or they're going to plan to take over the world. [02:45:19.940 --> 02:45:22.620] But first things first, can they actually plan? [02:45:22.620 --> 02:45:24.480] I have 15 years of experience working [02:45:24.480 --> 02:45:26.440] in reinforcement learning and cognitive science [02:45:26.440 --> 02:45:30.320] and in neuroscience, evaluating planning in humans and brains [02:45:30.320 --> 02:45:32.000] and reinforcement learning models. [02:45:32.000 --> 02:45:34.080] So I thought, OK, let's apply that. [02:45:34.080 --> 02:45:38.120] In order to accurately evaluate whether a cognitive capacity [02:45:38.120 --> 02:45:41.200] exists in an agent or in a biological system, [02:45:41.200 --> 02:45:43.900] there needs to be a systematic protocol to evaluate it. [02:45:43.900 --> 02:45:46.480] Inspired by cognitive science, we have two contributions here. [02:45:46.480 --> 02:45:49.660] First, we introduced COG eval, a systematic protocol [02:45:49.660 --> 02:45:54.000] for evaluating cognitive capacities. [02:45:54.000 --> 02:45:56.240] What that means is you need to operationalize [02:45:56.240 --> 02:45:59.040] a particular latent ability in terms of multiple tasks that [02:45:59.040 --> 02:46:00.160] can be measured. [02:46:00.160 --> 02:46:03.960] And these measurements need to unconfound or decouple [02:46:03.960 --> 02:46:06.960] certain confounds from what is actually [02:46:06.960 --> 02:46:09.480] being measured in terms of that cognitive ability. 
[02:46:09.480 --> 02:46:12.160] So for instance, if you give it some simple situations, [02:46:12.160 --> 02:46:14.720] it might be that it solves it, but you can't declare victory [02:46:14.720 --> 02:46:18.240] unless you show that the tasks that you have created somehow [02:46:18.240 --> 02:46:22.360] capture different aspects of the cognitive latent ability [02:46:22.360 --> 02:46:23.720] that you are measuring. [02:46:23.720 --> 02:46:26.160] Second, you want to operationalize it [02:46:26.160 --> 02:46:29.240] in terms of different structures, different domains, [02:46:29.240 --> 02:46:31.180] and different tasks. [02:46:31.180 --> 02:46:32.680] You don't want to measure one or two [02:46:32.680 --> 02:46:34.520] things in one or two environments [02:46:34.520 --> 02:46:36.900] and with an anecdote declare that something works [02:46:36.900 --> 02:46:38.220] or something exists. [02:46:38.220 --> 02:46:40.640] So here, what, for instance, you have [02:46:40.640 --> 02:46:42.160] is different graph structures. [02:46:42.160 --> 02:46:45.400] I have six structures that I'll show you, different domain. [02:46:45.400 --> 02:46:47.600] I'll show you the spatial domain, for instance. [02:46:47.600 --> 02:46:49.720] If I ask you for planning, I could ask you, [02:46:49.720 --> 02:46:51.240] how do you go to Hull Seafronter? [02:46:51.240 --> 02:46:53.120] Or I could give you an information [02:46:53.120 --> 02:46:54.800] about Ali is friends with Michael. [02:46:54.800 --> 02:46:57.000] Michael is friends with Mary, Mary's friends with Sue. [02:46:57.000 --> 02:46:58.720] If Ali wants to pass a message to Sue, [02:46:58.720 --> 02:47:00.480] what is the path, for instance? [02:47:00.480 --> 02:47:02.640] That's the planning in the social domain. [02:47:02.640 --> 02:47:05.360] So social and spatial domain, different domains, and also [02:47:05.360 --> 02:47:06.120] task conditions. [02:47:06.120 --> 02:47:07.840] We use 15 different tasks. [02:47:07.840 --> 02:47:09.320] These are inspired by various tasks [02:47:09.320 --> 02:47:11.040] that I have designed in the past. [02:47:11.040 --> 02:47:13.360] You can look at these two papers and others. [02:47:13.360 --> 02:47:15.720] This goes back 100 years ago to the tradition started [02:47:15.720 --> 02:47:17.800] by Edward Tolman on cognitive maps [02:47:17.800 --> 02:47:20.200] in rats and men, 1948 review paper, [02:47:20.200 --> 02:47:24.120] reviews 20 years of research, where it shows behaviorally [02:47:24.120 --> 02:47:26.720] how to measure whether an entity-- in that case, [02:47:26.720 --> 02:47:28.440] he was measuring rats-- [02:47:28.440 --> 02:47:30.600] possess a cognitive map. [02:47:30.600 --> 02:47:32.760] It was a revolutionary result at the time, [02:47:32.760 --> 02:47:35.960] because it went against the behavioral stigma of the time [02:47:35.960 --> 02:47:37.720] that you need a reward to learn structures. [02:47:37.720 --> 02:47:40.160] It showed that no rats can learn the cognitive map [02:47:40.160 --> 02:47:43.280] of the environment, even if you don't give them rewards. [02:47:43.280 --> 02:47:46.200] OK, come back to present day, 15 tasks [02:47:46.200 --> 02:47:48.120] in five different categories. [02:47:48.120 --> 02:47:50.680] The goal is to evaluate systematically [02:47:50.680 --> 02:47:53.120] whether LLMs can extract from descriptions [02:47:53.120 --> 02:47:55.000] of an environment the cognitive map. [02:47:55.000 --> 02:47:56.000] And what does that mean? 
[02:47:56.000 --> 02:47:58.040] It means, similar to the tradition from Tolman [02:47:58.040 --> 02:48:00.920] 100 years ago until now, can it solve particular tasks? [02:48:00.920 --> 02:48:02.720] Is it robust to certain tasks? [02:48:02.720 --> 02:48:06.160] Can it do flexible planning with respect to [02:48:06.160 --> 02:48:08.640] and in response to different kinds of tasks [02:48:08.640 --> 02:48:13.560] where you have maybe short or brief local changes [02:48:13.560 --> 02:48:15.960] to the environment, like a reward location changed [02:48:15.960 --> 02:48:17.340] or one edge changed? [02:48:17.340 --> 02:48:21.000] Can it integrate those to accurately plan, for instance? [02:48:21.000 --> 02:48:22.920] And we have these different graph structures. [02:48:22.920 --> 02:48:24.600] Just to give you an example of how it goes. [02:48:24.600 --> 02:48:27.520] So for graph A, domain is spatial [02:48:27.520 --> 02:48:29.360] and the task is value-based planning. [02:48:29.360 --> 02:48:30.480] What would it look like? [02:48:30.480 --> 02:48:33.520] I would describe the graph to the LLM as, [02:48:33.520 --> 02:48:35.260] you imagine a building with six rooms. [02:48:35.260 --> 02:48:37.120] From the lobby, you have two choices. [02:48:37.120 --> 02:48:38.240] You go to room one or two. [02:48:38.240 --> 02:48:40.600] From room one, there is a door to room three. [02:48:40.600 --> 02:48:42.560] From room three, there is a door to room five. [02:48:42.560 --> 02:48:44.160] In room five, there is $10. [02:48:44.160 --> 02:48:46.400] You don't take any money, because at the end, [02:48:46.400 --> 02:48:48.400] you only have one possibility to take money. [02:48:48.400 --> 02:48:49.240] You go back. [02:48:49.240 --> 02:48:52.340] From room two, you can go to room four and then to room six. [02:48:52.340 --> 02:48:55.140] And in room six, there is $50. [02:48:55.140 --> 02:48:56.400] And then the question is, [02:48:56.400 --> 02:48:58.160] and here, this was a description of the environment. [02:48:58.160 --> 02:49:00.560] Then the question is, you return to the lobby. [02:49:00.560 --> 02:49:03.560] You have only one choice, and you can only take money once. [02:49:03.560 --> 02:49:05.560] What is the optimal room to choose [02:49:05.560 --> 02:49:06.640] in order to take the most money? [02:49:06.640 --> 02:49:09.880] And you should say room two, because room six has the most money, right? [02:49:09.880 --> 02:49:12.520] So all of these environments are described in that way, [02:49:12.520 --> 02:49:14.840] either in the spatial domain or the social domain. [02:49:14.840 --> 02:49:17.000] And the different tasks are prompted like this. [02:49:17.000 --> 02:49:19.160] For cases where something in the environment changes, [02:49:19.160 --> 02:49:21.040] you can see how the second prompt, for instance, [02:49:21.040 --> 02:49:21.880] modifies something. [02:49:21.880 --> 02:49:24.500] Say, oh, now you learned that the reward in this room [02:49:24.500 --> 02:49:25.720] changed to such and such. [02:49:25.720 --> 02:49:28.320] Oh, now you learned that the door to this room [02:49:28.320 --> 02:49:30.400] has been changed, and it all of a sudden [02:49:30.400 --> 02:49:32.280] opens to this other room, right? [02:49:32.280 --> 02:49:33.960] Okay, now with that, please don't look here. [02:49:33.960 --> 02:49:35.360] I don't want you guys to cheat. [02:49:35.360 --> 02:49:36.880] And I know you guys might have heard things, [02:49:36.880 --> 02:49:38.520] but forget everything you heard.
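For listeners who want to see what such a prompt looks like on the page, here is an illustrative way to render the six-room example above as a value-based planning question. The wording and the helper are mine, not the benchmark's code.

```python
def room_prompt(doors, rewards, start="the lobby"):
    rooms = {r for pair in doors for r in pair} - {start}
    lines = [f"Imagine a building with {len(rooms)} rooms."]
    for a, b in doors:
        lines.append(f"From {a}, there is a door to {b}.")
    for room, amount in rewards.items():
        lines.append(f"In {room}, there is ${amount}.")
    lines.append(f"You return to {start}. You can take money only once. "
                 "Which room should you choose in order to take the most money?")
    return "\n".join(lines)

doors = [("the lobby", "room 1"), ("the lobby", "room 2"),
         ("room 1", "room 3"), ("room 3", "room 5"),
         ("room 2", "room 4"), ("room 4", "room 6")]
rewards = {"room 5": 10, "room 6": 50}
print(room_prompt(doors, rewards))   # optimal first choice: room 2, on the way to room 6
```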
[02:49:38.520 --> 02:49:40.760] Between these three, which one do you think [02:49:40.760 --> 02:49:43.680] is going to be the most difficult, and why? [02:49:43.680 --> 02:49:45.200] - So the choices are A, B, and C. [02:49:45.200 --> 02:49:46.120] - A, B, and C. [02:49:46.120 --> 02:49:47.840] Which graph is gonna be difficult, [02:49:47.840 --> 02:49:49.620] or are they gonna be the same, [02:49:49.620 --> 02:49:51.760] in terms of difficulty for the LLM to solve? [02:49:51.760 --> 02:49:52.600] - They're similar. [02:49:52.600 --> 02:49:55.120] So B has more branching, and C has more length. [02:49:55.120 --> 02:49:58.120] - So which one is gonna be more difficult to solve for LLMs? [02:49:58.120 --> 02:50:02.840] You can say different things, and we can see who is right. [02:50:02.840 --> 02:50:03.680] - I don't know the answer. [02:50:03.680 --> 02:50:05.560] I'll guess B, because more branching. [02:50:05.560 --> 02:50:08.300] - Okay, anybody guess anything else? [02:50:08.300 --> 02:50:10.640] - Probably C, because I guess LLM is not able [02:50:10.640 --> 02:50:13.040] to handle a very long-term sequence. [02:50:13.040 --> 02:50:14.520] - So we have two hypotheses here. [02:50:14.520 --> 02:50:16.280] Anybody think they're the same? [02:50:16.280 --> 02:50:18.320] - So between A, B, and C. [02:50:18.320 --> 02:50:19.280] - Between A, B, and C, which one is more difficult, [02:50:19.280 --> 02:50:20.320] or are they the same? [02:50:21.400 --> 02:50:22.840] - I just don't understand, when you say so, [02:50:22.840 --> 02:50:24.440] what kind of problem are you trying to solve? [02:50:24.440 --> 02:50:26.680] - This problem that we just mentioned here. [02:50:26.680 --> 02:50:29.000] There is some money somewhere at the end of them. [02:50:29.000 --> 02:50:31.460] One of the nodes that is terminal has the most money. [02:50:31.460 --> 02:50:32.300] - C is harder. [02:50:32.300 --> 02:50:33.120] - C is harder, okay. [02:50:33.120 --> 02:50:33.960] And then between D and E, [02:50:33.960 --> 02:50:35.520] which one do you think is harder? [02:50:35.520 --> 02:50:41.400] - More branching, so E? [02:50:41.400 --> 02:50:42.400] - You think E is harder? [02:50:42.400 --> 02:50:43.600] - If it's branching. [02:50:43.600 --> 02:50:45.080] I don't actually know that. [02:50:45.080 --> 02:50:48.560] I do feel like he has a point, so I can be wrong. [02:50:48.560 --> 02:50:49.920] - Okay, so you think E is harder. [02:50:49.920 --> 02:50:51.320] Anybody think D is harder? [02:50:51.320 --> 02:50:53.040] - Maybe D is harder. [02:50:53.040 --> 02:50:54.120] - Okay, why? [02:50:54.120 --> 02:50:56.840] - Because it has fewer ways to go from one point to another. [02:50:56.840 --> 02:50:58.720] - It has bottlenecks, right. [02:50:58.720 --> 02:51:00.720] Okay, ready? [02:51:00.720 --> 02:51:01.560] - Okay. [02:51:01.560 --> 02:51:02.380] - Right here. [02:51:02.380 --> 02:51:03.220] So take a look at this. [02:51:03.220 --> 02:51:05.520] B is harder than C, as you can see. [02:51:05.520 --> 02:51:06.360] - It's branching. [02:51:06.360 --> 02:51:07.400] - Right, B is harder. [02:51:07.400 --> 02:51:10.200] Even though C is twice as large as B [02:51:10.200 --> 02:51:11.960] in terms of the number of nodes. [02:51:11.960 --> 02:51:14.360] And A, you can see that it was easy, right? [02:51:14.360 --> 02:51:17.520] So imagine if I showed you this as the planning task, [02:51:17.520 --> 02:51:18.720] and I declared victory, and I said, [02:51:18.720 --> 02:51:20.000] look, LLMs can solve planning.
[02:51:20.000 --> 02:51:22.680] GPT-4, great, near 100%, right? [02:51:22.680 --> 02:51:24.520] But then you try just a little longer, [02:51:24.520 --> 02:51:25.960] or you have the same number of nodes, [02:51:25.960 --> 02:51:27.480] but with a branching structure. [02:51:27.480 --> 02:51:29.200] What do you see here? [02:51:29.200 --> 02:51:30.760] Huge drop, right? [02:51:30.760 --> 02:51:33.120] And in fact, what do you see for three of the LLMs? [02:51:33.120 --> 02:51:35.480] It's almost at zero percent, right? [02:51:35.480 --> 02:51:37.560] And now between D and E, let's take a look. [02:51:37.560 --> 02:51:41.200] As you can see, D is much more difficult for GPT-4, [02:51:41.200 --> 02:51:42.160] which is the blue one. [02:51:42.160 --> 02:51:44.680] In fact, E is more difficult than B. [02:51:44.680 --> 02:51:46.440] Sorry, B is more difficult than E, [02:51:46.440 --> 02:51:47.280] even though it's much smaller. [02:51:47.280 --> 02:51:48.120] - It's not consistent, yeah. [02:51:48.120 --> 02:51:50.000] - Well, there is something there. [02:51:50.000 --> 02:51:53.040] In these two, you have structures [02:51:53.040 --> 02:51:54.560] where you need to be exact. [02:51:54.560 --> 02:51:57.680] There are not multiple paths between different nodes, right? [02:51:57.680 --> 02:51:58.520] So it's very important. [02:51:58.520 --> 02:52:00.640] If you're going from this cluster to this one, [02:52:00.640 --> 02:52:02.920] you have to pass through this bottleneck. [02:52:02.920 --> 02:52:04.720] So there needs to be an ability [02:52:04.720 --> 02:52:08.160] to plan accurately through the specific bottleneck, correct? [02:52:08.160 --> 02:52:09.600] Now, what about the different tasks? [02:52:09.600 --> 02:52:10.760] Let's see. [02:52:10.760 --> 02:52:12.120] As you can see, they're not robust [02:52:12.120 --> 02:52:13.800] to the different tasks either. [02:52:13.800 --> 02:52:16.560] Traversal, which is one-step, two-step, [02:52:16.560 --> 02:52:18.960] three-step, n-step path, and value path. [02:52:18.960 --> 02:52:21.060] This is easier for these guys. [02:52:21.060 --> 02:52:22.420] Why is that? [02:52:22.420 --> 02:52:25.960] The reason is that traversal does not change [02:52:25.960 --> 02:52:28.180] the structure of the environment or the rewards. [02:52:28.180 --> 02:52:30.060] However, as soon as you have the local change, [02:52:30.060 --> 02:52:31.920] the stuff that Edward Tolman was talking about [02:52:31.920 --> 02:52:33.560] a hundred years ago that is required [02:52:33.560 --> 02:52:35.720] for measuring cognitive maps in rodents, for instance, [02:52:35.720 --> 02:52:37.680] like detour and shortcut that we have, [02:52:37.680 --> 02:52:38.760] all of a sudden you see a drop, [02:52:38.760 --> 02:52:40.840] and you can see all of a sudden it goes to zero [02:52:40.840 --> 02:52:45.440] for Cohere, for Alpaca, and for LLaMA, right? [02:52:45.440 --> 02:52:49.320] And so here you can see this sad thing also. [02:52:49.320 --> 02:52:52.600] It's almost at 0% for four of the graphs. [02:52:52.600 --> 02:52:55.800] So all of these graphs are almost at 0% [02:52:55.800 --> 02:53:00.800] for three of the LLMs, and about 20% for most of them, [02:53:00.800 --> 02:53:03.080] and it's only GPT-4 that does a little better, [02:53:03.080 --> 02:53:05.520] and that's about 40%, right?
[02:53:05.520 --> 02:53:07.040] So based on all of these things, [02:53:07.040 --> 02:53:10.740] robustness to task if you aggregate across graphs, [02:53:10.740 --> 02:53:14.140] not robust to tasks, and robustness to different graphs [02:53:14.140 --> 02:53:17.000] if you aggregate across tasks, also not very robust. [02:53:17.000 --> 02:53:20.440] So you compare these, the general conclusion I would draw [02:53:20.440 --> 02:53:22.380] is that they're not good at planning. [02:53:22.380 --> 02:53:25.040] Now let's take a look at some of their failure modes. [02:53:25.040 --> 02:53:27.400] So can you guys see what is the failure mode [02:53:27.400 --> 02:53:29.060] that is happening here? [02:53:29.060 --> 02:53:30.760] There is an edge that doesn't exist. [02:53:30.760 --> 02:53:33.860] It hallucinated an edge in giving the planning response [02:53:33.860 --> 02:53:35.360] that doesn't exist. [02:53:35.360 --> 02:53:37.140] Now let's take a look at this case, [02:53:37.140 --> 02:53:40.920] where you have a direct path from one to seven, [02:53:40.920 --> 02:53:42.240] but it's giving a very long, it says, [02:53:42.240 --> 02:53:44.040] "What is the shortest path between one and seven?" [02:53:44.040 --> 02:53:46.080] And it says one, 13, 10, seven. [02:53:46.080 --> 02:53:49.240] But interestingly, if I ask the LLM, [02:53:49.240 --> 02:53:51.000] can you list the tuples? [02:53:51.000 --> 02:53:53.120] GPT-4 can easily list the tuples, [02:53:53.120 --> 02:53:54.880] but at the same time still can hallucinate, [02:53:54.880 --> 02:53:56.240] like in this case. [02:53:56.240 --> 02:53:58.680] Now in this last one, there's two mistakes. [02:53:58.680 --> 02:54:00.020] I told you one of the mistakes, [02:54:00.020 --> 02:54:01.260] which is hallucinating the edges. [02:54:01.260 --> 02:54:02.760] What other mistake do you see? [02:54:02.760 --> 02:54:04.200] - Is it out of order somehow? [02:54:04.200 --> 02:54:05.040] I don't know. [02:54:05.040 --> 02:54:06.880] It's hard to tell from this distance. [02:54:06.880 --> 02:54:07.880] - Take a look at the answer. [02:54:07.880 --> 02:54:09.320] What is wrong with the answer? [02:54:09.320 --> 02:54:10.480] - It revisits a node. [02:54:10.480 --> 02:54:12.320] - Exactly, there's a loop. [02:54:12.320 --> 02:54:14.600] So a shortest path should not have a loop. [02:54:14.600 --> 02:54:15.440] - Of course. [02:54:15.440 --> 02:54:16.280] - So we found another case, right? [02:54:16.280 --> 02:54:17.100] - Yeah, yeah. [02:54:17.100 --> 02:54:19.880] - So these three failure modes are failures of planning. [02:54:19.880 --> 02:54:22.700] Even though it knows the one-step tuples correctly, [02:54:22.700 --> 02:54:24.600] it seems to fail at planning. [02:54:24.600 --> 02:54:27.800] And it can give you some insight into what is going on. [02:54:27.800 --> 02:54:30.360] So it's not very good at stitching one-step things together. [02:54:30.360 --> 02:54:31.400] So based on that, [02:54:31.400 --> 02:54:33.600] why do you think it was better at graph A? [02:54:33.600 --> 02:54:35.240] Can people give me guesses? [02:54:35.240 --> 02:54:37.640] Why do you think graph A was easier? [02:54:37.640 --> 02:54:39.880] Or it showed some apparent success on graph A. [02:54:39.880 --> 02:54:41.920] Why do you think that is? [02:54:41.920 --> 02:54:43.760] - Well, it has fewer choices. [02:54:43.760 --> 02:54:44.920] - Fewer choices. [02:54:44.920 --> 02:54:46.640] But this one is also very few choices. [02:54:46.640 --> 02:54:48.160] It's like B. 
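Those failure modes, hallucinated edges and loops in a supposedly shortest path, are mechanical enough to check automatically. Here is a minimal sketch, assuming the graph is given as undirected edge tuples; the edge list below is illustrative, not one of the paper's graphs.

```python
def check_path(path, edges):
    """Flag the two failure modes discussed above: hallucinated edges
    and revisited nodes in a proposed shortest path."""
    edge_set = {frozenset(e) for e in edges}
    problems = []
    for a, b in zip(path, path[1:]):
        if frozenset((a, b)) not in edge_set:
            problems.append(f"hallucinated edge {a}-{b}")
    if len(set(path)) != len(path):
        problems.append("path revisits a node (contains a loop)")
    return problems

# Example: the model answers 1 -> 13 -> 10 -> 7 on a graph that actually
# contains a direct 1-7 edge (this edge list is made up for illustration).
edges = [(1, 7), (1, 13), (10, 7)]
print(check_path([1, 13, 10, 7], edges))  # reports the nonexistent 13-10 edge
```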
[02:54:48.160 --> 02:54:49.920] It's a tree, right? [02:54:49.920 --> 02:54:51.960] This has fewer choices than that, right? [02:54:51.960 --> 02:54:53.920] But why is this so much more difficult? [02:54:53.920 --> 02:54:58.360] - More ways to be wrong. [02:54:58.360 --> 02:54:59.200] - Say again? [02:54:59.200 --> 02:55:00.200] - More ways to be wrong. [02:55:00.200 --> 02:55:02.400] - More ways to be wrong. [02:55:02.400 --> 02:55:04.440] So another way to say it [02:55:04.440 --> 02:55:06.840] is that the things that showed up exact [02:55:06.840 --> 02:55:08.840] in the kind of the prompt [02:55:08.840 --> 02:55:11.320] are more likely to work for C and A, basically. [02:55:11.320 --> 02:55:13.160] So if it just did just memorization [02:55:13.160 --> 02:55:15.080] of what's going on, right? [02:55:15.080 --> 02:55:20.080] Because it's just sort of a kind of a two tracks here. [02:55:20.080 --> 02:55:21.960] But there were more branching. [02:55:21.960 --> 02:55:22.880] - For what it's worth, [02:55:22.880 --> 02:55:25.760] I think a lot of the common sense reasoning benchmarks [02:55:25.760 --> 02:55:28.240] that these things are specifically trained on [02:55:28.240 --> 02:55:29.440] are transitive. [02:55:29.440 --> 02:55:32.240] I don't know what you call this. [02:55:32.240 --> 02:55:33.640] - Yeah. [02:55:33.640 --> 02:55:35.520] - We trained it to be good at A and C. [02:55:35.520 --> 02:55:36.360] - No, that's not true. [02:55:36.360 --> 02:55:37.200] - No? [02:55:37.200 --> 02:55:40.240] - GPT-4 has been trained on a huge amount of text. [02:55:40.240 --> 02:55:42.120] A lot of that is family trees [02:55:42.120 --> 02:55:44.520] and structures that are actually tree-like. [02:55:44.520 --> 02:55:46.120] It turns out transformers, in fact, [02:55:46.120 --> 02:55:48.680] do have some limitations with tree-like structures [02:55:48.680 --> 02:55:49.880] and with things that are bottleneck. [02:55:49.880 --> 02:55:51.440] We are very good at bottleneck. [02:55:51.440 --> 02:55:53.880] In fact, bottlenecks makes things easier for us, right? [02:55:53.880 --> 02:55:58.320] You have a few nodes that are basically, [02:55:58.320 --> 02:56:00.960] they have high centrality, [02:56:00.960 --> 02:56:04.240] especially eigencentrality or betweenness centrality. [02:56:04.240 --> 02:56:05.840] And basically what you do is [02:56:05.840 --> 02:56:07.880] when you're solving a problem in planning, [02:56:07.880 --> 02:56:08.800] you say, "I'm going to find that. [02:56:08.800 --> 02:56:10.480] "No, then from there I'll go somewhere." [02:56:10.480 --> 02:56:11.680] You have a subway system. [02:56:11.680 --> 02:56:13.400] You go to 14th station in New York City, [02:56:13.400 --> 02:56:15.520] then you can find a train that goes somewhere else, right? [02:56:15.520 --> 02:56:17.640] So if you get lost, just find a hub. [02:56:17.640 --> 02:56:20.000] We actually use these heuristics a lot. [02:56:20.000 --> 02:56:22.960] It's available in human texts a lot, [02:56:22.960 --> 02:56:25.240] but it hasn't been picking up on that. [02:56:25.240 --> 02:56:27.600] So this is about the structures that the transformer, [02:56:27.600 --> 02:56:29.160] for instance, might have been learning. [02:56:29.160 --> 02:56:30.040] And as you can see, [02:56:30.040 --> 02:56:32.640] you have here from seven billion parameters [02:56:32.640 --> 02:56:33.920] to one trillion parameters, [02:56:33.920 --> 02:56:36.360] to the best of our knowledge, or larger, right? 
[02:56:36.360 --> 02:56:38.560] And none of them is capable of figuring out [02:56:38.560 --> 02:56:41.920] or having a high performance higher than like, [02:56:41.920 --> 02:56:45.120] we have something between zero and 40% [02:56:45.120 --> 02:56:46.720] on a simple two-step tree, [02:56:46.720 --> 02:56:47.800] which is the simplest thing [02:56:47.800 --> 02:56:49.200] you can give a model-based planner. [02:56:49.200 --> 02:56:51.840] It's not even probabilistic, it's deterministic. [02:56:51.840 --> 02:56:54.600] And even that is failing, right? [02:56:54.600 --> 02:56:56.560] And then we saw these failure modes. [02:56:56.560 --> 02:56:59.560] Another thing, what if I give it extra instructions? [02:56:59.560 --> 02:57:02.520] By the way, all of these have been told things step-by-step, [02:57:02.520 --> 02:57:04.200] so we give that simple chain of thought. [02:57:04.200 --> 02:57:06.160] What if we give it extra instructions? [02:57:06.160 --> 02:57:09.480] For instance, I describe entire breadth-first search [02:57:09.480 --> 02:57:10.400] and depth-first search, [02:57:10.400 --> 02:57:12.520] and I say, "Hey, use depth-first search. [02:57:12.520 --> 02:57:13.680] "How is that working? [02:57:13.680 --> 02:57:15.440] "First do this, then do that." [02:57:15.440 --> 02:57:16.280] And another one. [02:57:16.280 --> 02:57:18.400] So you can see in the supplementary material of our paper, [02:57:18.400 --> 02:57:19.960] the entire sort of breadth-first search [02:57:19.960 --> 02:57:21.480] and depth-first search. [02:57:21.480 --> 02:57:23.920] Then you see that it improves somewhat [02:57:23.920 --> 02:57:26.400] for when you are within a cluster, [02:57:26.400 --> 02:57:29.320] but when you look a situation in this graph D, [02:57:29.320 --> 02:57:31.760] where you need to find the shortest path [02:57:31.760 --> 02:57:35.560] between nodes that are a cluster away from each other, [02:57:35.560 --> 02:57:37.280] what you see is that it doesn't help much. [02:57:37.280 --> 02:57:39.240] And interestingly, for different temperatures, [02:57:39.240 --> 02:57:41.280] for temperature zero, it doesn't help at all. [02:57:41.280 --> 02:57:43.240] It helps a little bit for temperatures that are higher, [02:57:43.240 --> 02:57:45.480] and I guess take different kind of paths. [02:57:45.480 --> 02:57:48.480] But it's interesting, only one cluster away. [02:57:48.480 --> 02:57:50.720] Shortest path one cluster away is not a big deal. [02:57:50.720 --> 02:57:53.800] The diameter of this network is not that large. [02:57:53.800 --> 02:57:55.320] There is not a lot of improvement, [02:57:55.320 --> 02:57:57.840] and the performance is pretty low, as you can see, [02:57:57.840 --> 02:57:58.680] for all of them. [02:57:58.680 --> 02:58:01.560] And for three of them, it's actually closer to zero. [02:58:01.560 --> 02:58:04.280] So this is the evaluation. [02:58:04.280 --> 02:58:07.720] We have done, together with my summer interns, [02:58:07.720 --> 02:58:09.640] we have a paper where we did [02:58:09.640 --> 02:58:12.960] a prefrontal cortex-inspired modular architecture [02:58:12.960 --> 02:58:14.960] where GPT-4 basically plays the role [02:58:14.960 --> 02:58:17.320] of these different kind of modules [02:58:17.320 --> 02:58:20.200] and solves these problems in a kind of a modular way, [02:58:20.200 --> 02:58:21.600] similar to the prefrontal cortex. [02:58:21.600 --> 02:58:24.480] I have like 15 years of working on prefrontal cortex. 
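For reference, the breadth-first search she describes spelling out in the prompt for that extra-instructions experiment is only a few lines when written as code; this is a generic sketch, not the prompt text from the paper's supplementary material.

```python
from collections import deque

def bfs_shortest_path(graph, start, goal):
    """Plain BFS over an adjacency-list graph; returns one shortest path."""
    queue = deque([[start]])
    visited = {start}
    while queue:
        path = queue.popleft()
        node = path[-1]
        if node == goal:
            return path
        for nxt in graph.get(node, []):
            if nxt not in visited:
                visited.add(nxt)
                queue.append(path + [nxt])
    return None  # goal unreachable

# Toy adjacency list, purely illustrative.
graph = {1: [2, 3], 2: [4], 3: [4], 4: [5]}
print(bfs_shortest_path(graph, 1, 5))  # -> [1, 2, 4, 5]
```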
[02:58:24.480 --> 02:58:26.400] I'm very excited to do this with these models. [02:58:26.400 --> 02:58:27.840] You can see it here. [02:58:27.840 --> 02:58:30.400] And this paper, you can find it on my website, [02:58:30.400 --> 02:58:32.720] Webb et al. 2023, and you can find it here as well. [02:58:32.720 --> 02:58:35.720] I have an arXiv number over there. [02:58:35.720 --> 02:58:37.600] - Okay, so I had to cut it for time there, [02:58:37.600 --> 02:58:39.360] but literally, I'm not joking, [02:58:39.360 --> 02:58:42.400] I had another half an hour of audio just chatting with her [02:58:42.400 --> 02:58:44.640] and all of us just crowding around her like students. [02:58:44.640 --> 02:58:47.400] She was just very, very engaging in person. [02:58:47.400 --> 02:58:48.400] And I love to see that. [02:58:48.400 --> 02:58:52.840] I love to see when people can not only do great work, [02:58:52.840 --> 02:58:56.000] but then also talk in a compelling fashion about it, [02:58:56.000 --> 02:58:58.720] not just passively answer questions about it, [02:58:58.720 --> 02:59:00.760] but also challenge you to think along the way. [02:59:00.760 --> 02:59:03.200] So I guess if I were to include one agents paper [02:59:03.200 --> 02:59:05.240] from NeurIPS, this would be it. [02:59:05.240 --> 02:59:07.880] And for the final talk of this entire pod, [02:59:07.880 --> 02:59:10.200] which is already stretching into three hours, [02:59:10.200 --> 02:59:12.760] I have saved the coverage of state space models, [02:59:12.760 --> 02:59:15.200] which have been the talk of the town. [02:59:15.200 --> 02:59:17.920] The Mamba model was released a few days before NeurIPS [02:59:17.920 --> 02:59:19.000] and Albert Gu was there. [02:59:19.000 --> 02:59:21.880] I met him, but I couldn't get a conversation with him. [02:59:21.880 --> 02:59:24.960] But Chris Ré was on stage talking about effectively [02:59:24.960 --> 02:59:27.400] all of Hazy Research, what Stanford's doing [02:59:27.400 --> 02:59:28.560] and what Chris Ré is up to [02:59:28.560 --> 02:59:30.160] and all the people he's associated with, [02:59:30.160 --> 02:59:32.600] including Tri Dao and Albert Gu. [02:59:32.600 --> 02:59:36.440] So if you want a primer or a good entry point [02:59:36.440 --> 02:59:39.760] on just how Chris Ré is thinking about state space models, [02:59:39.760 --> 02:59:41.240] I think this is it. [02:59:41.240 --> 02:59:42.440] - So as I mentioned, our motivation [02:59:42.440 --> 02:59:46.120] for getting rid of attention, potentially, is long sequences. [02:59:46.120 --> 02:59:47.280] That's the practical motivation. [02:59:47.280 --> 02:59:50.360] I'll come back to my real motivation in one slide. [02:59:50.360 --> 02:59:52.880] Practically, some data comes as long sequences. [02:59:52.880 --> 02:59:56.200] Data, audio, DNA is billions of base pairs. [02:59:56.200 --> 02:59:58.360] We can also cram in tons of few-shot examples, [02:59:58.360 --> 02:59:59.660] which seems pretty cool. [02:59:59.660 --> 03:00:01.920] When we started this project, [03:00:01.920 --> 03:00:03.440] really the standard models couldn't handle it. [03:00:03.440 --> 03:00:06.280] GPT-1 had only a 512 context length. [03:00:06.280 --> 03:00:08.680] And as I mentioned, transformers are scaling quadratically [03:00:08.680 --> 03:00:09.680] in their sequence length. [03:00:09.680 --> 03:00:12.080] So we kind of took two parallel paths to this. [03:00:12.080 --> 03:00:14.200] One is better hardware algorithms.
[03:00:14.200 --> 03:00:16.220] So we tried with flash attention [03:00:16.220 --> 03:00:17.520] and now people have followed up [03:00:17.520 --> 03:00:19.720] to make that path really, really fast. [03:00:19.720 --> 03:00:21.560] Just optimize the crap out of it on hardware [03:00:21.560 --> 03:00:24.240] and there's a lot of juice to squeeze there. [03:00:24.240 --> 03:00:25.800] The other approach, which I'll talk about now, [03:00:25.800 --> 03:00:26.960] are new models. [03:00:27.800 --> 03:00:29.120] Now, as I mentioned, [03:00:29.120 --> 03:00:31.200] I actually wasn't totally motivated by that. [03:00:31.200 --> 03:00:33.820] I wasn't, honestly, that wasn't my total motivation. [03:00:33.820 --> 03:00:36.840] I was really motivated by this inductive bias issue. [03:00:36.840 --> 03:00:39.840] So the idea here is you give me this image [03:00:39.840 --> 03:00:42.800] and I flatten it into one single pixel. [03:00:42.800 --> 03:00:44.720] And then I ask you, is it a car or a boat? [03:00:44.720 --> 03:00:45.800] Some CIFAR-like thing. [03:00:45.800 --> 03:00:48.280] Sequential CIFAR, if you know the task. [03:00:48.280 --> 03:00:49.500] And this is really interesting to me [03:00:49.500 --> 03:00:52.040] because when a human would do this, [03:00:52.040 --> 03:00:52.880] this would be hopeless. [03:00:52.880 --> 03:00:53.700] If you gave me a picture [03:00:53.700 --> 03:00:56.360] and gave me a one-pixel vector as a flat thing, [03:00:56.360 --> 03:00:58.360] I would have no chance of classifying it. [03:00:58.360 --> 03:01:01.080] Machines could do something, but there was a huge gap. [03:01:01.080 --> 03:01:02.360] And I wanted to understand [03:01:02.360 --> 03:01:05.240] why is there this inductive bias underneath the covers? [03:01:05.240 --> 03:01:07.400] Do you really need this spatial inductive bias [03:01:07.400 --> 03:01:08.360] for the machines to reason? [03:01:08.360 --> 03:01:11.240] Do they have to reason like us when they do this? [03:01:11.240 --> 03:01:13.280] So I was fascinated by this problem. [03:01:13.280 --> 03:01:14.120] All right. [03:01:14.120 --> 03:01:16.780] So there's another benchmark that came out [03:01:16.780 --> 03:01:19.160] that was really exciting from the Google folks, [03:01:19.160 --> 03:01:22.680] which was about how to benchmark efficient attention. [03:01:22.680 --> 03:01:23.800] It's called Long Range Arena. [03:01:23.800 --> 03:01:25.060] It's extremely cool. [03:01:25.060 --> 03:01:27.140] We found them basically because we were playing around [03:01:27.140 --> 03:01:28.600] with these sequential CIFAR things, [03:01:28.600 --> 03:01:30.840] and they had a much greater library of places [03:01:30.840 --> 03:01:34.600] where they were seeing possibilities to improve attention. [03:01:34.600 --> 03:01:37.520] This was the leaderboard in 2021 of this attention, [03:01:37.520 --> 03:01:38.840] and they were basically looking at a bunch [03:01:38.840 --> 03:01:40.620] of very cool linear attention variants, [03:01:40.620 --> 03:01:42.880] some of which we still play with. [03:01:42.880 --> 03:01:45.780] I want to draw your attention to two columns on this thing. [03:01:45.780 --> 03:01:46.960] The first is image. [03:01:46.960 --> 03:01:49.680] That is that sequential CIFAR task I was just talking about. [03:01:49.680 --> 03:01:51.200] It's a really interesting task. 
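To make "sequential CIFAR" concrete: the 2D image is read out as a flat sequence of pixels, so the model never sees the spatial layout. A tiny sketch; the shapes assume CIFAR's 32x32 RGB images, and the real benchmark uses the actual dataset rather than random data.

```python
import numpy as np

# A stand-in CIFAR image: 32x32 pixels, 3 color channels (random values).
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

# Sequential CIFAR: unroll the pixels row by row into a length-1024 sequence
# (grayscale here for simplicity), discarding the 2D spatial structure.
sequence = image.mean(axis=-1).reshape(-1)
print(sequence.shape)  # (1024,)
```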
[03:01:51.200 --> 03:01:54.440] You've probably trained CIFAR to the 90s or high 80s [03:01:54.440 --> 03:01:56.880] on your laptop or on a small GPU, [03:01:56.880 --> 03:01:58.160] and you see the sequential version [03:01:58.160 --> 03:02:00.360] was lagging quite a bit behind. [03:02:00.360 --> 03:02:02.120] The other column is this thing, PathX, [03:02:02.120 --> 03:02:04.640] which were these large images where you had two dots, [03:02:04.640 --> 03:02:06.920] and you're trying to say, are the two dots connected? [03:02:06.920 --> 03:02:08.800] And the reason there are Xs is that every model [03:02:08.800 --> 03:02:10.960] was basically random guessing at this point. [03:02:10.960 --> 03:02:13.980] So there are three approaches that we were trying [03:02:13.980 --> 03:02:15.320] to improve long sequences. [03:02:15.320 --> 03:02:19.920] Improve the utilization on hardware, approximate attention, [03:02:19.920 --> 03:02:22.280] and this last one, which I'm gonna talk about most, [03:02:22.280 --> 03:02:24.360] which is using RNN-based kinds of ideas, [03:02:24.360 --> 03:02:25.800] and signal processing ideas. [03:02:25.800 --> 03:02:27.960] All of them are great. [03:02:27.960 --> 03:02:30.520] I just happened to pick the last one. [03:02:30.520 --> 03:02:31.760] So the idea is we're gonna replace [03:02:31.760 --> 03:02:34.880] just the signal processing box, the signal mixing box, [03:02:34.880 --> 03:02:37.540] with this new operator, S4, [03:02:37.540 --> 03:02:39.540] that's based on signal processing ideas. [03:02:39.540 --> 03:02:42.200] So this was inspired by Albert and Karan. [03:02:42.200 --> 03:02:44.320] Albert's now a professor at CMU. [03:02:44.320 --> 03:02:45.920] Karan is now running this company, Cartesia, [03:02:45.920 --> 03:02:48.380] which is a small company, just started. [03:02:48.380 --> 03:02:50.920] And basically, S4 is a classic state-space model. [03:02:50.920 --> 03:02:52.760] So if you're an EE person, you've seen these [03:02:52.760 --> 03:02:54.280] in like your undergrad right away. [03:02:54.280 --> 03:02:55.560] It's an LTI system. [03:02:55.560 --> 03:02:57.880] But we're gonna tweak it for deep learning. [03:02:57.880 --> 03:02:59.940] The first thing we're gonna get, as I'll show you, [03:02:59.940 --> 03:03:01.200] pretty mathematically and nicely, [03:03:01.200 --> 03:03:02.440] is that signal processing people [03:03:02.440 --> 03:03:04.260] are obsessed with stability. [03:03:04.260 --> 03:03:06.840] They understand bounded input, bounded output stability [03:03:06.840 --> 03:03:07.680] like nobody's business. [03:03:07.680 --> 03:03:10.600] It's simple and it's clean, and we can use it right away. [03:03:10.600 --> 03:03:13.520] This is a challenge when training these models. [03:03:13.520 --> 03:03:15.900] A second thing which was quite surprising is, [03:03:15.900 --> 03:03:17.800] I've always thought about CNNs and RNNs [03:03:17.800 --> 03:03:19.360] as quite distinct models. [03:03:19.360 --> 03:03:21.200] But what I'm gonna show you mathematically is, [03:03:21.200 --> 03:03:23.240] these models actually unify both. [03:03:23.240 --> 03:03:25.360] Now, these are CNNs in a kind of different way [03:03:25.360 --> 03:03:26.260] than we're used to. [03:03:26.260 --> 03:03:27.960] They're convolutions where the filters [03:03:27.960 --> 03:03:29.780] are potentially as long as the input.
[03:03:29.780 --> 03:03:32.080] But we're gonna be able to view the exact same weights [03:03:32.080 --> 03:03:35.760] and operate on them either as an RNN or a CNN, [03:03:35.760 --> 03:03:36.960] which is quite exciting. [03:03:36.960 --> 03:03:39.240] And the last piece, of course, [03:03:39.240 --> 03:03:40.880] is that we're gonna make this quite fast. [03:03:40.880 --> 03:03:42.120] And these are gonna be asymptotically [03:03:42.120 --> 03:03:43.640] more efficient than transformers. [03:03:43.640 --> 03:03:45.480] We're eventually gonna be able to process sequence [03:03:45.480 --> 03:03:48.880] in like n log n time, which is then a challenge [03:03:48.880 --> 03:03:52.240] to make practical, and I'll share some results there. [03:03:52.240 --> 03:03:54.240] Now, this thing is extremely simple. [03:03:54.240 --> 03:03:56.680] Very simple, very simple signal processing ideas. [03:03:56.680 --> 03:03:58.600] But I just wanna point out it had a large improvement [03:03:58.600 --> 03:04:00.520] on LRA that surprised me. [03:04:00.520 --> 03:04:02.380] So here's the improvement on LRA. [03:04:02.380 --> 03:04:04.520] This is the first of its kind to solve PathX. [03:04:04.520 --> 03:04:07.080] It was like a 26-point jump on this benchmark [03:04:07.080 --> 03:04:09.600] that a bunch of folks had played at. [03:04:09.600 --> 03:04:12.320] I also wanna point out that the image task, [03:04:12.320 --> 03:04:14.900] that spatial bias seems to matter less than I thought. [03:04:14.900 --> 03:04:16.880] And that was really the thing that was interesting to me. [03:04:16.880 --> 03:04:18.960] And since then, many people have followed on [03:04:18.960 --> 03:04:20.560] and pushed these numbers up higher, [03:04:20.560 --> 03:04:21.720] but I just think that's really interesting. [03:04:21.720 --> 03:04:23.320] I don't know what to do with the observation, [03:04:23.320 --> 03:04:25.160] but I really like it. [03:04:25.160 --> 03:04:28.180] Okay, so what is signal processing? [03:04:28.180 --> 03:04:30.280] Well, signal processing people view a signal [03:04:30.280 --> 03:04:33.100] of d dimensions at n time steps as input, [03:04:33.100 --> 03:04:35.900] and an output is a signal of d dimension at n time steps. [03:04:35.900 --> 03:04:37.800] That looks a lot like our X and O matrix [03:04:37.800 --> 03:04:39.760] that we had in attention. [03:04:39.760 --> 03:04:41.080] They also think causally. [03:04:41.080 --> 03:04:43.600] They think that time moves left to right through this, [03:04:43.600 --> 03:04:46.040] and things like GPT are also kind of causal. [03:04:46.040 --> 03:04:48.440] So so far, what I wanna emphasize is, [03:04:48.440 --> 03:04:49.560] we've really done nothing. [03:04:49.560 --> 03:04:51.340] It's just symbol pushing that we've been able [03:04:51.340 --> 03:04:52.880] to move into this model. [03:04:52.880 --> 03:04:55.320] So what does signal processing actually buy us? [03:04:55.320 --> 03:04:57.200] Two big ideas. [03:04:57.200 --> 03:04:59.600] The first is, over 100 years, [03:04:59.600 --> 03:05:01.360] they figured out a bunch of models [03:05:01.360 --> 03:05:03.880] which are relatively simple, [03:05:03.880 --> 03:05:06.060] but capture pretty interesting phenomenon. [03:05:06.060 --> 03:05:08.480] These aren't the best models you could ever use, [03:05:08.480 --> 03:05:10.600] these LTI systems, but they're a simple [03:05:10.600 --> 03:05:12.480] and very well-understood starting point. [03:05:12.480 --> 03:05:14.440] So I argue, makes sense to start there. 
[03:05:15.560 --> 03:05:17.820] The second piece, which I think a lot of machine learners [03:05:17.820 --> 03:05:21.160] don't necessarily love, is that they have this idea [03:05:21.160 --> 03:05:23.400] that a signal is a continuous object [03:05:23.400 --> 03:05:25.640] that then is discretely sampled. [03:05:25.640 --> 03:05:28.200] And that idea allows us to do a bunch of stuff. [03:05:28.200 --> 03:05:30.680] In particular, it allows us to use all our discrete tricks [03:05:30.680 --> 03:05:33.040] which are more common in machine learning and AI, [03:05:33.040 --> 03:05:36.360] but also a bunch of 20th century mathematics [03:05:36.360 --> 03:05:38.600] that knows how to do integrals and solve things exactly. [03:05:38.600 --> 03:05:40.160] And I'll show you at least one of those tricks [03:05:40.160 --> 03:05:42.080] in the next couple of slides. [03:05:42.080 --> 03:05:43.780] I think it's an incredibly powerful idea, [03:05:43.780 --> 03:05:46.360] and it was really helpful for us to think about it. [03:05:46.360 --> 03:05:49.000] And as I said, it's gonna teach us about stability [03:05:49.000 --> 03:05:50.120] in like a trivial way. [03:05:50.120 --> 03:05:52.000] We're gonna use theorems from the 1800s [03:05:52.000 --> 03:05:54.160] to be able to prove that our models are stable, [03:05:54.160 --> 03:05:56.440] which I just think is awesome. [03:05:56.440 --> 03:05:57.560] All right, so what's an LTI system? [03:05:57.560 --> 03:05:58.720] If you've never played with one, [03:05:58.720 --> 03:06:01.800] this is what's called a single input, single output system. [03:06:01.800 --> 03:06:03.440] You have some curve that's coming in, [03:06:03.440 --> 03:06:05.760] which is typically called u(t), that's the input, [03:06:05.760 --> 03:06:07.980] and some output curve, y(t). [03:06:07.980 --> 03:06:09.480] You have some hidden state, [03:06:09.480 --> 03:06:11.160] which is much higher dimension usually [03:06:11.160 --> 03:06:12.360] than the input and the output. [03:06:12.360 --> 03:06:14.900] We'll take the hidden state as large as the input [03:06:14.900 --> 03:06:15.800] when it's discretized. [03:06:15.800 --> 03:06:18.280] It's gonna be a huge thing, okay? [03:06:18.280 --> 03:06:21.680] Now, I haven't told you how that hidden state evolves yet, [03:06:21.680 --> 03:06:23.120] but it's gonna be constrained. [03:06:23.120 --> 03:06:25.520] And the LTI people say there's lots of things [03:06:25.520 --> 03:06:28.040] that can fit into basically letting it evolve [03:06:28.040 --> 03:06:29.520] according to an ODE, okay? [03:06:29.520 --> 03:06:32.360] So I'm gonna show you that in just one second. [03:06:32.360 --> 03:06:34.080] So here's what you need for the ODE. [03:06:34.080 --> 03:06:35.600] You need two matrices, A and B, [03:06:35.600 --> 03:06:37.800] and we're gonna learn those matrices. [03:06:37.800 --> 03:06:39.480] And basically, it says that the hidden state [03:06:39.480 --> 03:06:41.440] can only evolve according to this equation. [03:06:41.440 --> 03:06:43.920] It basically says the change in the hidden state [03:06:43.920 --> 03:06:45.960] is proportional to some learned function [03:06:45.960 --> 03:06:48.560] of the input plus the previous state, okay? [03:06:48.560 --> 03:06:51.660] The output is then from projection, [03:06:51.660 --> 03:06:52.660] from this linear projection, [03:06:52.660 --> 03:06:55.420] from this high dimensional state down to 1D. [03:06:55.420 --> 03:06:58.280] This is all that an LTI system does, okay?
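For reference, what he is describing is the standard continuous-time state-space model; the projection he mentions is usually written as a third matrix, C.

```latex
% Continuous-time LTI state-space model: the hidden state x(t) evolves by
% a linear ODE driven by the input u(t), and the output y(t) is a linear
% projection of that higher-dimensional state.
\dot{x}(t) = A\,x(t) + B\,u(t), \qquad y(t) = C\,x(t)
```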
[03:06:58.280 --> 03:07:00.600] I'm just saying it's something that's surprisingly powerful [03:07:00.600 --> 03:07:01.440] and well understood. [03:07:01.440 --> 03:07:02.500] This is not the best model. [03:07:02.500 --> 03:07:03.680] If you're a signal processing person, [03:07:03.680 --> 03:07:05.080] you say, oh, you should use X, Y, or Z. [03:07:05.080 --> 03:07:06.200] You're probably right, [03:07:06.200 --> 03:07:07.840] but we wanna start with something really, really simple [03:07:07.840 --> 03:07:10.040] that we can understand all the way. [03:07:10.040 --> 03:07:12.880] All right, so it turns out that one of the beautiful things [03:07:12.880 --> 03:07:14.840] is because it has this continuous object [03:07:14.840 --> 03:07:15.900] lurking in the background, [03:07:15.900 --> 03:07:18.160] you can use high school calculus. [03:07:18.160 --> 03:07:20.680] And in particular, you can get out this nice expression. [03:07:20.680 --> 03:07:22.160] And what this says is the hidden state [03:07:22.160 --> 03:07:24.880] is exactly this function, X of S, [03:07:24.880 --> 03:07:27.920] and this convolutional style integral, okay? [03:07:27.920 --> 03:07:29.320] This is exactly what it is. [03:07:29.320 --> 03:07:32.080] This is wonderful, you just solve the ODE. [03:07:32.080 --> 03:07:34.160] Then when we realize it, we have to discretize. [03:07:34.160 --> 03:07:36.440] We'll come back to that in a second. [03:07:36.440 --> 03:07:38.040] So the immediate win is, [03:07:38.960 --> 03:07:42.320] well, this can tell us exactly when the system is stable. [03:07:42.320 --> 03:07:44.000] Basically, as long as the eigenvalues [03:07:44.000 --> 03:07:45.400] are in the left-hand part of the plane, [03:07:45.400 --> 03:07:47.840] which every EE person memorizes, [03:07:47.840 --> 03:07:50.160] and the reason left-hand part of the complex plane matters [03:07:50.160 --> 03:07:53.480] is E to those values goes inside the unit disk, [03:07:53.480 --> 03:07:55.240] you know that this thing is not gonna blow up. [03:07:55.240 --> 03:07:56.640] This system is gonna have bounded input, [03:07:56.640 --> 03:07:59.480] bounded output stability, which is really exciting, okay? [03:07:59.480 --> 03:08:02.760] So when we train, we can fix our A's, our representations, [03:08:02.760 --> 03:08:04.760] so that the eigenvalues satisfy this property, [03:08:04.760 --> 03:08:07.240] and that's gonna be one of the arts. [03:08:07.240 --> 03:08:08.840] Now, to implement this on a machine, [03:08:08.840 --> 03:08:10.240] we can't use continuous objects. [03:08:10.240 --> 03:08:12.120] We have to use them as discrete, [03:08:12.120 --> 03:08:14.920] and integrals are just big, smooth sums, basically. [03:08:14.920 --> 03:08:17.040] They're actually nicer to deal with than functions, [03:08:17.040 --> 03:08:18.880] and so what we'll do is we'll break that sum [03:08:18.880 --> 03:08:21.280] down into functions, and what happens in signal processing [03:08:21.280 --> 03:08:23.000] is you think that you're gonna sample [03:08:23.000 --> 03:08:25.000] at some regular frequency T, [03:08:25.000 --> 03:08:27.200] and then what I'm denoting here is X bracket K [03:08:27.200 --> 03:08:31.480] means the K-th sample, which is at the point KT, okay? [03:08:31.480 --> 03:08:32.560] So you're seeing this animation [03:08:32.560 --> 03:08:35.920] that the integral is just this nice, smooth sum, all right? [03:08:35.920 --> 03:08:37.480] Cool. 
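In symbols, the solution, the stability condition, and the sampling step he just walked through are the standard LTI facts (nothing here is specific to S4):

```latex
% Solving the ODE gives the hidden state as a convolution of the input
% with a matrix-exponential kernel (zero initial state assumed):
x(t) = \int_0^t e^{A(t-s)}\, B\, u(s)\, ds
% Bounded-input, bounded-output stability holds when every eigenvalue of A
% lies in the left half of the complex plane:
\operatorname{Re}\,\lambda_i(A) < 0 \quad \text{for all } i
% Discretization: sample at a fixed step T and write x[k] = x(kT), so the
% integral becomes a sum over the sampled inputs u[0], ..., u[k].
```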
[03:08:37.480 --> 03:08:39.480] All right, so now that we're in discrete land, [03:08:39.480 --> 03:08:42.600] we can relate it to more familiar machine learning concepts. [03:08:42.600 --> 03:08:44.160] The first thing is you can view this [03:08:44.160 --> 03:08:45.880] as a recurrence, as an RNN. [03:08:45.880 --> 03:08:47.680] So I'll introduce notation G here, [03:08:47.680 --> 03:08:49.440] which is basically the B times the input. [03:08:49.440 --> 03:08:52.320] It's all the modifications on the input that we had, [03:08:52.320 --> 03:08:54.460] and with just a little bit of arithmetic, [03:08:54.460 --> 03:08:57.040] I can move it out so that I get X of N plus one. [03:08:57.040 --> 03:09:00.240] The next hidden state is T times GN [03:09:00.240 --> 03:09:02.480] plus some term that's kind of down-weighting it, [03:09:02.480 --> 03:09:04.480] and I'm illustrating the down-weighting here [03:09:04.480 --> 03:09:05.560] in the visualization. [03:09:06.560 --> 03:09:10.560] RNNs are super fast, so if we did manage to learn the weights, [03:09:10.560 --> 03:09:13.320] the Bs, the As, all the rest of these things in the filter, [03:09:13.320 --> 03:09:15.680] then we could run this as an RNN automatically [03:09:15.680 --> 03:09:17.240] from the same parameterizations. [03:09:17.240 --> 03:09:18.080] Super cool. [03:09:18.080 --> 03:09:20.480] With just a little bit more notation, [03:09:20.480 --> 03:09:23.040] I can take that E term, the exponential there, [03:09:23.040 --> 03:09:26.600] and put that matrix exponential into this function F, [03:09:26.600 --> 03:09:28.080] and that becomes a convolution [03:09:28.080 --> 03:09:30.080] that's probably more familiar to most people, [03:09:30.080 --> 03:09:31.880] which is a discrete convolution. [03:09:31.880 --> 03:09:34.760] But notice this discrete convolution is of length N. [03:09:34.760 --> 03:09:36.360] It's a huge, long convolution. [03:09:36.360 --> 03:09:38.600] It's not a three-by-three convolution like a ResNet. [03:09:38.600 --> 03:09:40.840] It's actually as long, potentially, as the filter, [03:09:40.840 --> 03:09:42.560] so that's gonna be challenging to process, [03:09:42.560 --> 03:09:45.160] but this model says they're basically both the same. [03:09:45.160 --> 03:09:49.860] So the key technical challenge is to make these SSMs fast. [03:09:49.860 --> 03:09:52.400] Those long comms are hard. [03:09:52.400 --> 03:09:55.720] If you think about it, that F, that filter, is huge, [03:09:55.720 --> 03:09:57.680] and so if you materialize it at every time step, [03:09:57.680 --> 03:09:58.880] you'd be toast. [03:09:58.880 --> 03:10:00.320] It turns out that you don't ever have [03:10:00.320 --> 03:10:01.700] to materialize the hidden state. [03:10:01.700 --> 03:10:03.500] That's a really important observation. [03:10:03.500 --> 03:10:05.980] That allows you to go fast and allows you to have runtime [03:10:05.980 --> 03:10:07.860] that's proportional to the input and the output, [03:10:07.860 --> 03:10:09.440] not the massive hidden state. [03:10:09.440 --> 03:10:11.360] The hidden state is important for representation, [03:10:11.360 --> 03:10:14.280] but it's actually not important for implementation. [03:10:14.280 --> 03:10:15.240] You can check out the blog. [03:10:15.240 --> 03:10:18.480] The blog has more details about exactly how that works. 
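A tiny numerical sketch of that equivalence, with made-up matrices rather than anything from S4 itself: the same A, B, C weights can be run either as a step-by-step recurrence or as one long convolution, and the outputs match.

```python
import numpy as np

# Toy discretized state-space model, run two ways: as a recurrence (RNN view)
# and as a long convolution (CNN view). Illustrative only; S4 adds structure
# to A and a fast algorithm on top of this basic equivalence.
rng = np.random.default_rng(0)
d, n = 4, 16                       # hidden size, sequence length
A = 0.9 * np.eye(d)                # stable toy transition matrix
B = rng.standard_normal((d, 1))
C = rng.standard_normal((1, d))
u = rng.standard_normal(n)

# RNN view: x_{k+1} = A x_k + B u_k, then project with C.
x = np.zeros((d, 1))
y_rnn = []
for k in range(n):
    x = A @ x + B * u[k]
    y_rnn.append((C @ x).item())

# CNN view: y = conv(u, K) with the length-n filter K_j = C A^j B.
K = np.array([(C @ np.linalg.matrix_power(A, j) @ B).item() for j in range(n)])
y_cnn = [sum(K[j] * u[k - j] for j in range(k + 1)) for k in range(n)]

print(np.allclose(y_rnn, y_cnn))  # True: same weights, two execution modes
```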
[03:10:18.480 --> 03:10:20.600] The second thing which we spent a lot of time on, [03:10:20.600 --> 03:10:22.600] and Albert did a bunch of really brilliant things [03:10:22.600 --> 03:10:25.280] inspired by the Legendre Memory Units, [03:10:25.280 --> 03:10:27.880] is how do we make that A have that nice eigenvalue structure [03:10:27.880 --> 03:10:29.400] so that we know it's stable? [03:10:29.400 --> 03:10:31.880] Things like diagonal matrices make it really easy [03:10:31.880 --> 03:10:35.240] to keep this structure, because they're scalars [03:10:35.240 --> 03:10:37.020] on the diagonal, and you can keep it. [03:10:37.020 --> 03:10:38.920] But computing matrix exponentials in general [03:10:38.920 --> 03:10:41.440] for expressive classes is actually pretty challenging. [03:10:41.440 --> 03:10:44.120] So we had to do a ton of work to get that to happen [03:10:44.120 --> 03:10:45.880] over the last couple of years. [03:10:45.880 --> 03:10:47.620] And the last bit is this practical, [03:10:47.620 --> 03:10:49.480] fast convolution that we needed. [03:10:49.480 --> 03:10:51.560] Now, I love this slide because DALL-E 3 [03:10:51.560 --> 03:10:52.840] made most of the art in this talk, [03:10:52.840 --> 03:10:54.000] or all of the art in this talk, [03:10:54.000 --> 03:10:55.080] and it made this poster. [03:10:55.080 --> 03:10:56.080] I didn't give it the tagline. [03:10:56.080 --> 03:10:57.160] I still think it's hysterical. [03:10:57.160 --> 03:10:59.280] Too fast, too furious, revving up the equations. [03:10:59.280 --> 03:11:01.120] I have no idea what that means, but I love it. [03:11:01.120 --> 03:11:02.360] Supposed to be Fourier, by the way. [03:11:02.360 --> 03:11:03.760] That's the thing. [03:11:03.760 --> 03:11:06.000] In any case, the thing is, we had to do [03:11:06.000 --> 03:11:09.200] the same type of operation that we did in FlashAttention, [03:11:09.200 --> 03:11:11.480] but now on FFTs and convolutions. [03:11:11.480 --> 03:11:15.120] If you naively run FFTs, you have terrible memory behavior. [03:11:15.120 --> 03:11:17.400] If you can somehow group them together in nice ways [03:11:17.400 --> 03:11:19.260] and be I/O aware, you can get back [03:11:19.260 --> 03:11:21.360] to that kind of nice utilization. [03:11:21.360 --> 03:11:24.440] FlashAttention, if you recall, was about 72% utilization. [03:11:24.440 --> 03:11:27.240] Dan and Hermann got to 65% utilization. [03:11:27.240 --> 03:11:30.260] I would also say that Dan's on the faculty market this year [03:11:30.260 --> 03:11:31.640] and Hermann's on the PhD market, [03:11:31.640 --> 03:11:32.760] and you'd be smart to hire them. [03:11:32.760 --> 03:11:34.160] They're amazing. [03:11:34.160 --> 03:11:36.480] So the point is, there's not really a hardware trade-off [03:11:36.480 --> 03:11:38.320] after you do a bunch of work. [03:11:38.320 --> 03:11:39.560] It's really algorithmic. [03:11:39.560 --> 03:11:42.440] This thing is gonna do a lot fewer operations. [03:11:42.440 --> 03:11:44.040] And this led to what some folks have called, [03:11:44.040 --> 03:11:45.560] Sasha called, an RNN renaissance. [03:11:45.560 --> 03:11:47.780] And I wanna say it's been super fun.
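For reference, the FFT trick behind that fast long convolution is the textbook one: a naive causal convolution is quadratic in the sequence length, while the FFT version is n log n. A minimal numpy sketch, with none of the I/O-aware blocking that the real speedups come from:

```python
import numpy as np

# Causal long convolution y[k] = sum_j K[j] * u[k-j], computed two ways:
# directly (O(n^2)) and via the FFT (O(n log n)).
rng = np.random.default_rng(0)
n = 1024
u = rng.standard_normal(n)
K = rng.standard_normal(n)         # filter as long as the input

y_direct = np.array([np.dot(K[: k + 1], u[k::-1]) for k in range(n)])

# Zero-pad to 2n so circular FFT convolution matches linear convolution.
y_fft = np.fft.irfft(np.fft.rfft(u, 2 * n) * np.fft.rfft(K, 2 * n))[:n]

print(np.allclose(y_direct, y_fft))  # True
```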
[03:11:47.780 --> 03:11:49.360] I have to say the last year and a half [03:11:49.360 --> 03:11:51.960] to two years of research, I've absolutely loved [03:11:51.960 --> 03:11:54.340] because you had a ton of people contributing amazing ideas [03:11:54.340 --> 03:11:58.560] like S5 and Mega and RWKV on super technical topics [03:11:58.560 --> 03:12:00.680] that were really exciting for us to do. [03:12:00.680 --> 03:12:02.940] And there's just been so many more that I can't put on here [03:12:02.940 --> 03:12:05.600] and they've been pushing the state of the art. [03:12:05.600 --> 03:12:07.560] So now you've listened to my talk and you're like, [03:12:07.560 --> 03:12:09.520] should we use these models everywhere? [03:12:09.520 --> 03:12:12.320] And maybe I'm a California optimist, so I sound happy. [03:12:12.320 --> 03:12:14.400] I know it's irritating, but I'm happy about everything. [03:12:14.400 --> 03:12:15.360] So I am. [03:12:15.360 --> 03:12:16.920] So you're like, maybe you should use these things. [03:12:16.920 --> 03:12:19.120] I say, well, maybe. [03:12:19.120 --> 03:12:21.560] But there's actually a pretty big gap on language. [03:12:21.560 --> 03:12:24.340] So it was wonderful on LRA and those signal processing tasks, [03:12:24.340 --> 03:12:26.300] but when we actually deployed it on language, [03:12:26.300 --> 03:12:27.140] there was a gap. [03:12:28.140 --> 03:12:32.820] Now, the standard way you measure a language model [03:12:32.820 --> 03:12:33.740] is perplexity. [03:12:33.740 --> 03:12:36.280] This is the score of how predictable the language is. [03:12:36.280 --> 03:12:38.160] To give you a sense of this measure, [03:12:38.160 --> 03:12:42.100] S4 was five points worse on perplexity versus transformers. [03:12:42.100 --> 03:12:44.340] And that's a staggering number because five points [03:12:44.340 --> 03:12:47.200] is about the difference between a 125 million parameter model [03:12:47.200 --> 03:12:48.820] and a seven billion parameter model. [03:12:48.820 --> 03:12:50.020] It was a big gap. [03:12:50.020 --> 03:12:52.280] So we started to wonder, why is that? [03:12:52.280 --> 03:12:54.220] So we went back to work that other folks had done, [03:12:54.220 --> 03:12:57.040] which was amazing, this associative recall task. [03:12:57.040 --> 03:12:59.760] So the task here is I give you letters and numbers. [03:12:59.760 --> 03:13:02.140] The last letter is a query, in this case, C. [03:13:02.140 --> 03:13:03.540] And you have to tell me which number [03:13:03.540 --> 03:13:04.820] is associated with that letter. [03:13:04.820 --> 03:13:06.340] It's a lookup task. [03:13:06.340 --> 03:13:08.940] Attention can crush this because it's a very easy [03:13:08.940 --> 03:13:10.200] lookup task. [03:13:10.200 --> 03:13:12.100] These two variants of S4 that came out later [03:13:12.100 --> 03:13:14.580] that are supposed to be better on language were better, [03:13:14.580 --> 03:13:16.780] but there was a gap here, too. [03:13:16.780 --> 03:13:18.540] And so without going into too much detail on this piece, [03:13:18.540 --> 03:13:21.140] Michael Poli came along and did this thing, Hyena, [03:13:21.140 --> 03:13:25.460] and he showed he could get 100% on this underlying operator [03:13:25.460 --> 03:13:27.360] and did it in a very exciting way [03:13:27.360 --> 03:13:29.720] while still maintaining speed and all the rest.
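The associative recall synthetic he describes is easy to reproduce; here is an illustrative generator (the exact format of the benchmark may differ), and the multi-query variant discussed next simply asks for several keys at the end instead of one.

```python
import random
import string

def make_associative_recall(num_pairs=8, seed=0):
    """Key-value pairs followed by a query key; the answer is its value.
    Illustrative version of the lookup task described above."""
    rng = random.Random(seed)
    keys = rng.sample(string.ascii_lowercase, num_pairs)
    pairs = [(k, rng.randint(0, 9)) for k in keys]
    query, answer = rng.choice(pairs)
    prompt = " ".join(f"{k} {v}" for k, v in pairs) + f" {query}"
    return prompt, answer

prompt, answer = make_associative_recall()
print(prompt)   # key-value pairs, then the query key at the end
print(answer)   # the value that was paired with that key
```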
[03:13:29.720 --> 03:13:32.060] So this is what the picture looked like [03:13:32.060 --> 03:13:34.800] as of a couple of months ago, or a couple of weeks ago, [03:13:34.800 --> 03:13:36.540] I guess, two weeks ago. [03:13:36.540 --> 03:13:38.560] You had S4, which was a bit worse, [03:13:38.560 --> 03:13:41.040] but then in quick succession, people were coming down [03:13:41.040 --> 03:13:42.760] to this very strong attention baseline. [03:13:42.760 --> 03:13:44.000] All the baselines are released. [03:13:44.000 --> 03:13:45.320] Eleuther made a wonderful harness. [03:13:45.320 --> 03:13:47.200] These are all at 350 million. [03:13:47.200 --> 03:13:48.880] You can start to play with these things. [03:13:48.880 --> 03:13:51.840] And RWKV has been releasing even bigger models. [03:13:51.840 --> 03:13:53.820] And so there was this baseline here. [03:13:53.820 --> 03:13:56.120] These are closing the gap without attention. [03:13:56.120 --> 03:13:57.560] But part of the reason I love academia [03:13:57.560 --> 03:13:59.440] is you can worry about tiny problems. [03:13:59.440 --> 03:14:01.280] It's like, well, it seems like a small problem, [03:14:01.280 --> 03:14:02.720] but why is it worse? [03:14:02.720 --> 03:14:04.760] And so we kept asking, we kept poking at it, [03:14:04.760 --> 03:14:06.600] and Simran and Sabri came in, [03:14:06.600 --> 03:14:08.400] and they actually came up with this idea. [03:14:08.400 --> 03:14:10.000] It took us a surprising amount of time, [03:14:10.000 --> 03:14:11.680] but it was just a small twist. [03:14:11.680 --> 03:14:13.360] The small twist is what a transformer can do [03:14:13.360 --> 03:14:15.600] is not one lookup, but many lookups. [03:14:15.600 --> 03:14:18.000] So what MQAR is, is multi-queries. [03:14:18.000 --> 03:14:20.580] We don't just look up one letter, we look up many letters. [03:14:20.580 --> 03:14:22.520] Now we can worry about scaling in the letters, [03:14:22.520 --> 03:14:24.800] the vocab size, the model dimension. [03:14:24.800 --> 03:14:26.560] And what we found is that all of these models [03:14:26.560 --> 03:14:29.400] can, quote, solve the task, but how they do it, [03:14:29.400 --> 03:14:31.200] their scaling is quite different. [03:14:31.200 --> 03:14:32.400] And this relates to a bunch of things [03:14:32.400 --> 03:14:34.360] in parallel circuit complexity that I won't get to, [03:14:34.360 --> 03:14:35.880] but this is a really interesting thing [03:14:35.880 --> 03:14:38.200] where we can start to study the scaling. [03:14:38.200 --> 03:14:39.660] And so what they realized is that attention [03:14:39.660 --> 03:14:42.120] can solve these things with a small number of dimensions, [03:14:42.120 --> 03:14:45.760] roughly logarithmic, whereas Hyena and RWKV require, [03:14:45.760 --> 03:14:47.160] and all the convolutional models, [03:14:47.160 --> 03:14:48.720] as a result of their reduction, [03:14:48.720 --> 03:14:50.200] require things, model dimensions, [03:14:50.200 --> 03:14:52.000] that scale with the sequence length. [03:14:52.840 --> 03:14:54.260] And so you get charts that look like this. [03:14:54.260 --> 03:14:57.160] They'll solve it, but they need more capacity to do so. [03:14:57.160 --> 03:15:01.200] So when we started looking at these MQAR things in the wild, [03:15:01.200 --> 03:15:04.320] we started thinking, well, okay, MQAR is a nice synthetic, [03:15:04.320 --> 03:15:05.160] but does it translate? [03:15:05.160 --> 03:15:06.100] And this was really insightful.
[03:15:06.100 --> 03:15:07.400] Simran and Sabri did this. [03:15:07.400 --> 03:15:08.760] They said, we're gonna take the Pile, [03:15:08.760 --> 03:15:11.760] and we're gonna segment out which ones are AR-like, [03:15:11.760 --> 03:15:13.040] which sentences are AR-like. [03:15:13.040 --> 03:15:15.120] So these are things that have repeated bigrams. [03:15:15.120 --> 03:15:16.920] Common buzzard is repeated twice. [03:15:16.920 --> 03:15:18.280] There's kind of an implicit lookup [03:15:18.280 --> 03:15:21.280] the second time you're doing the common buzzard task. [03:15:21.280 --> 03:15:23.360] That's about 7% of the Pile. [03:15:23.360 --> 03:15:26.480] The non-AR slice was basically everything else. [03:15:26.480 --> 03:15:29.800] What they found is that the attention gap, 82% of it, [03:15:29.800 --> 03:15:32.640] was explained, even though this is a pretty rough proxy [03:15:32.640 --> 03:15:35.120] for the task, by just what's going on here. [03:15:35.120 --> 03:15:37.400] And this made us think, maybe if we solve this task, [03:15:37.400 --> 03:15:38.820] we can even close it. [03:15:38.820 --> 03:15:40.480] But the other observation was, [03:15:40.480 --> 03:15:43.620] actually these convolutional models are slightly better [03:15:43.620 --> 03:15:45.000] on the non-lookup task. [03:15:45.000 --> 03:15:47.420] So maybe there's hope to go beyond them. [03:15:47.420 --> 03:15:49.080] And so we started this kind of architecture, [03:15:49.080 --> 03:15:51.040] and I wanna give another shout out here to a paper I love. [03:15:51.040 --> 03:15:53.560] I love the T5 paper, I'm sure many of you do too. [03:15:53.560 --> 03:15:54.840] I love the vibe of it where it's like, [03:15:54.840 --> 03:15:57.200] hey, we just wanna say what are the common elements [03:15:57.200 --> 03:15:58.080] that are going on. [03:15:58.080 --> 03:16:00.440] If you're outside this little tiny sub-community, [03:16:00.440 --> 03:16:02.480] all the papers look very, very different. [03:16:02.480 --> 03:16:04.700] But if you're inside, I would say there's a couple [03:16:04.700 --> 03:16:05.900] of really common themes. [03:16:05.900 --> 03:16:07.760] And Simran and Sabri tried to boil them down [03:16:07.760 --> 03:16:09.660] so that more folks can participate [03:16:09.660 --> 03:16:12.000] and come into the field in an easier way. [03:16:12.000 --> 03:16:14.320] The themes are long convolutions, [03:16:14.320 --> 03:16:16.280] convolutions that are scaling with the input, [03:16:16.280 --> 03:16:18.360] not necessarily the full input size. [03:16:18.360 --> 03:16:20.080] Gating is a wonderful idea. [03:16:20.080 --> 03:16:22.160] That's multiplying in this kind of component-wise way [03:16:22.160 --> 03:16:24.200] in the sequence, that's an old idea. [03:16:24.200 --> 03:16:25.160] And data dependence. [03:16:25.160 --> 03:16:27.260] And Mamba just came out from Albert and Tri, [03:16:27.260 --> 03:16:30.440] which did this and still kept that sub-quadratic runtime. [03:16:30.440 --> 03:16:33.320] Based is basically just simplifying all of the things [03:16:33.320 --> 03:16:35.440] that people are doing and trying to get to something nice. [03:16:35.440 --> 03:16:37.680] We don't have T5 level niceness yet, [03:16:37.680 --> 03:16:39.660] but we are inspired by that. [03:16:39.660 --> 03:16:41.100] One thing I wanna point out is that [03:16:41.100 --> 03:16:45.000] this new convolutional architecture does scale for MQAR [03:16:45.000 --> 03:16:46.200] a little bit like attention.
[03:16:46.200 --> 03:16:48.120] So it has the same kind of dimension scaling [03:16:48.120 --> 03:16:50.120] that the others had, which is interesting. [03:16:50.120 --> 03:16:52.920] So the point is, is very recently, [03:16:52.920 --> 03:16:55.340] this is in the last week run up to NeurIPS, [03:16:55.340 --> 03:16:57.680] both Mamba and Based, and I'm sure five others [03:16:57.680 --> 03:16:59.460] will come out in the next couple of weeks, [03:16:59.460 --> 03:17:01.560] are now attention-free and actually getting you [03:17:01.560 --> 03:17:03.780] lower PPL at 350. [03:17:03.780 --> 03:17:05.200] Doesn't mean they're gonna get you lower PPL [03:17:05.200 --> 03:17:08.560] necessarily at 100 billion, but it's interesting to say [03:17:08.560 --> 03:17:11.400] there doesn't seem to be any fundamental kind of block, [03:17:11.400 --> 03:17:13.300] and that's, to me, extremely exciting. [03:17:13.300 --> 03:17:16.520] I did wanna point out a little bit that [03:17:16.520 --> 03:17:18.700] there is another bottleneck that's lurking [03:17:18.700 --> 03:17:20.720] for truly sub-quadratic models. [03:17:20.720 --> 03:17:22.880] We talked a lot about the signal processing part, [03:17:22.880 --> 03:17:25.040] but there's also this MLPs, and I've become obsessed [03:17:25.040 --> 03:17:26.840] with them, there's a whole line of work, [03:17:26.840 --> 03:17:29.280] check out Dan Fu's talk about trying to understand [03:17:29.280 --> 03:17:32.000] what's going on with the MLPs and can we slim those down. [03:17:32.000 --> 03:17:35.560] They become a bottleneck at much larger dimension sizes. [03:17:35.560 --> 03:17:37.020] So the questions that were driving our work [03:17:37.020 --> 03:17:37.860] really were threefold. [03:17:37.860 --> 03:17:39.120] I shared with you, I hope, a little bit [03:17:39.120 --> 03:17:40.920] about how foundation models change the systems [03:17:40.920 --> 03:17:42.080] that we're building. [03:17:42.080 --> 03:17:44.480] I also talked a lot about how classical ideas [03:17:44.480 --> 03:17:46.320] from signal processing and databases [03:17:46.320 --> 03:17:48.920] were interesting bits of canon to bring into the field [03:17:48.920 --> 03:17:51.800] so that maybe we can make these models more efficient. [03:17:51.800 --> 03:17:54.040] What I thought I would end with is just why I think [03:17:54.040 --> 03:17:55.600] there's such a bright future in AI [03:17:55.600 --> 03:17:57.640] for two minutes and systems. [03:17:57.640 --> 03:18:00.280] The first thing is, we weren't using these models [03:18:00.280 --> 03:18:03.320] really 15 to 18 months ago in the way we're using them now. [03:18:03.320 --> 03:18:05.280] We knew intuitively that you train them once [03:18:05.280 --> 03:18:07.680] and use them multiple times, but it's not really clear [03:18:07.680 --> 03:18:09.340] we were doing that, we were kinda just showing them [03:18:09.340 --> 03:18:11.200] to each other, if we're honest. [03:18:11.200 --> 03:18:13.840] Now, people are using them on a daily basis, [03:18:13.840 --> 03:18:16.200] and inference has become an unbelievable task, [03:18:16.200 --> 03:18:18.480] I would say even the last three or four months, [03:18:18.480 --> 03:18:20.160] the speed of inference, if you watch [03:18:20.160 --> 03:18:21.720] on a bunch of the commercial servers, [03:18:21.720 --> 03:18:24.360] are just going through the roof as new ideas come in. 
[03:18:24.360 --> 03:18:25.760] Of course, people were thinking about this, [03:18:25.760 --> 03:18:28.480] MQA and GQA a while ago, speculative decoding [03:18:28.480 --> 03:18:31.120] was an amazing paper, vLLM was really exciting, [03:18:31.120 --> 03:18:34.440] Flash-Decoding, MatFormer, there's a ton of exciting work here. [03:18:34.440 --> 03:18:37.500] My point is, this really kicked off like six months ago. [03:18:37.500 --> 03:18:40.440] Wild to think about, but that's the whole thing. [03:18:40.440 --> 03:18:42.600] Another bit is, there's a big difference [03:18:42.600 --> 03:18:45.600] between low-latency systems and high-throughput systems, [03:18:45.600 --> 03:18:47.000] when you don't care if it returns [03:18:47.000 --> 03:18:48.920] in a couple of milliseconds, but you wanna, say, [03:18:48.920 --> 03:18:50.440] run on a hundred different documents, [03:18:50.440 --> 03:18:52.120] or a million different documents. [03:18:52.120 --> 03:18:55.440] We're just at the outset of seeing that systems pitch, [03:18:55.440 --> 03:18:57.400] as people are actually using these foundation models [03:18:57.400 --> 03:18:59.760] on all the back-of-house data cleaning-ish tasks [03:18:59.760 --> 03:19:01.760] that I think are gonna happen in the next while. [03:19:01.760 --> 03:19:04.400] There's new data types, I do wanna call out [03:19:04.400 --> 03:19:06.320] that there's all kinds of things you could worry about [03:19:06.320 --> 03:19:08.440] from Koonle, about how to program these systems, [03:19:08.440 --> 03:19:09.980] what are the right accelerators and hardware, [03:19:09.980 --> 03:19:12.480] that's just happening, what are the right systems to build [03:19:12.480 --> 03:19:14.520] that are systems of record underneath the covers, [03:19:14.520 --> 03:19:16.160] there's tons of stuff. [03:19:16.160 --> 03:19:19.000] - Yep, I gave Chris a little bit more time there, [03:19:19.000 --> 03:19:20.880] because he's such a legend, and he covers [03:19:20.880 --> 03:19:24.080] so many different concepts and updates and models [03:19:24.080 --> 03:19:25.520] in such a small amount of time, [03:19:25.520 --> 03:19:28.160] so his time is very high quality, [03:19:28.160 --> 03:19:29.680] and you should watch the whole talk [03:19:29.680 --> 03:19:31.300] if you get the opportunity. [03:19:31.300 --> 03:19:34.080] But that's it for our coverage of NeurIPS 2023, [03:19:34.080 --> 03:19:37.400] it's just a ton of papers, we are gonna follow up [03:19:37.400 --> 03:19:40.640] with a lot of the startups that I encountered and met, [03:19:40.640 --> 03:19:42.400] a lot of which are returning guests, [03:19:42.400 --> 03:19:45.200] so keep a lookout for that, but also thank you so much [03:19:45.200 --> 03:19:48.320] for listening in on this, it's an experimental new format, [03:19:48.320 --> 03:19:51.400] we grabbed a whole bunch of audio, spliced in [03:19:51.400 --> 03:19:55.220] live interviews, stage talks, and some of my own commentary [03:19:55.220 --> 03:19:57.200] with a little bit of backing music, [03:19:57.200 --> 03:20:00.200] it's an experimental new thing, did you like it? [03:20:00.200 --> 03:20:02.980] Let us know, if you liked it, then share it with a friend, [03:20:02.980 --> 03:20:05.580] that would help us a lot, and also just remember, [03:20:05.580 --> 03:20:06.840] we have a listener survey going on, [03:20:06.840 --> 03:20:10.060] so please come to our website and fill out our survey. [03:20:10.060 --> 03:20:12.200] - Thanks, and see you at the next NeurIPS recap.
[03:20:12.200 --> 03:20:13.320] DJ QD outro. [03:20:13.320 --> 03:20:25.980] (upbeat music)