Stanford XCS224U: NLU I Fantastic Language Models and How to Build Them, Part 2 I Spring 2023


00:00:00.000 | All right. Welcome everyone.
00:00:07.400 | Again, we have a very full day.
00:00:09.400 | The plan is to finish up our review of core information retrieval stuff.
00:00:15.280 | The focus will be on neural information retrieval,
00:00:18.160 | where a lot of the action is these days.
00:00:20.840 | I have then a few datasets to show you,
00:00:23.240 | and then I'm going to turn it over to Sid,
00:00:25.240 | and Sid is going to help us talk again about how to
00:00:28.400 | build fantastic language models.
00:00:31.200 | So let's dive in. We'll start by using our big handout here,
00:00:36.520 | information retrieval.
00:00:38.400 | Right. So here we are,
00:00:40.040 | and we are going to skip.
00:00:41.520 | That's right. I had a couple more metrics that I wanted to show you.
00:00:44.760 | So let's start there.
00:00:45.680 | So last time we talked about how assessment in the space of IR should be multidimensional.
00:00:52.800 | We've been focused on accuracy,
00:00:54.920 | but I will make amends.
00:00:56.080 | We are going to circle back and talk about these other dimensions,
00:00:59.200 | which I regard as absolutely crucial in this space.
00:01:03.040 | But with that said, we did dive into different metrics.
00:01:06.680 | We talked about success and reciprocal rank.
00:01:10.120 | Success, you should think of as just saying,
00:01:12.840 | for my chosen k,
00:01:14.760 | is there a star above me?
00:01:16.720 | That is, is there a relevant document above k?
00:01:20.220 | So it's a very coarse-grained measure.
00:01:22.480 | So this one here,
00:01:23.800 | if we set success at 2, D1
00:01:26.560 | has a success of 1 because there is a star at 2 or above.
00:01:30.240 | D2, that ranking also has a success of 1 because there is a star at 2 or above,
00:01:36.160 | and poor D3 gets a success of 0.
00:01:39.760 | And you can see already that it's coarse-grained because D1 and D2 are differentiated,
00:01:45.320 | in some intuitive sense,
00:01:46.600 | but here they both got a success score of 1.
00:01:50.140 | Reciprocal rank is a little bit better in the sense that it's more or less just
00:01:55.940 | registering whether there's a star at or above k,
00:01:58.900 | except now we are sensitive to the top most ranked one.
00:02:03.020 | So for example, D1 here has an RR at 2 of 1 because there is a star in first place.
00:02:09.820 | Whereas D2 has 1 over 2 because the first star is in second place.
00:02:16.500 | And then D3 still gets its poor 0.
00:02:19.940 | So pretty coarse-grained but very intuitive and sometimes success and RR are
00:02:25.500 | good metrics in the sense that you kind of just want to know for
00:02:28.620 | your chosen k whether you hit the mark,
00:02:31.220 | whether you got a star.
00:02:32.380 | And especially if you only have one relevant document per query,
00:02:36.140 | you might as well use these metrics.
00:02:38.500 | And then RR will just be a little bit more nuanced.
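To make these definitions concrete, here is a minimal sketch (my own code, not from the lecture) that computes success@k and RR@k from a ranked list of 0/1 relevance labels, using the three example rankings as I read them off the slides:

```python
def success_at_k(relevances, k):
    """1.0 if any relevant document appears at rank k or above, else 0.0.
    `relevances` is a list of 0/1 labels in ranked order (index 0 = rank 1)."""
    return 1.0 if any(relevances[:k]) else 0.0

def reciprocal_rank_at_k(relevances, k):
    """1/rank of the highest-ranked relevant document at or above k, else 0.0."""
    for rank, rel in enumerate(relevances[:k], start=1):
        if rel:
            return 1.0 / rank
    return 0.0

# The three example rankings (1 = starred/relevant document):
d1 = [1, 1, 0, 0, 0, 1]
d2 = [0, 1, 0, 0, 1, 1]
d3 = [0, 0, 1, 1, 1, 0]

print([success_at_k(d, 2) for d in (d1, d2, d3)])          # [1.0, 1.0, 0.0]
print([reciprocal_rank_at_k(d, 2) for d in (d1, d2, d3)])  # [1.0, 0.5, 0.0]
```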
00:02:41.780 | We also talked about precision and recall,
00:02:44.180 | the classic accuracy style metrics in this space.
00:02:47.960 | The differentiator here from the previous ones is that
00:02:50.960 | these are going to be sensitive to multiple stars.
00:02:53.620 | So if you have more than one document that's relevant to your query,
00:02:57.900 | you will be able to detect that.
00:03:00.440 | So we have this notion of a return value that is just the set of documents k or above.
00:03:05.700 | And then the relevant documents,
00:03:07.500 | those are the ones with stars.
00:03:09.340 | And precision is saying for my chosen k,
00:03:12.540 | what percentage of the things at or above k are relevant?
00:03:17.260 | And that's precision in the sense that if you picked k,
00:03:20.380 | you're looking at the set of documents and you want to
00:03:22.540 | know how many of them have stars relative to the total.
00:03:25.900 | Or like the reverse of precision would be like which ones
00:03:29.340 | are kind of imprecise as predictions because there's no star there.
00:03:33.060 | And then recall is kind of the dual of that and it says for my chosen k,
00:03:37.920 | how many of the stars made it up to k or above?
00:03:41.940 | And the opposite of that would be like how many stars are lingering down below?
00:03:46.940 | So you can see here because of the numerator that we're going to differentiate
00:03:50.740 | systems now based on how many stars are at k or above.
00:03:54.820 | So it's sensitive to multiple stars.
00:03:57.460 | So just to walk through again,
00:03:59.440 | precision at 2 for D1 is 2 out of 2.
00:04:02.660 | For D2, it's 1 out of 2 because just half of them have a star.
00:04:08.060 | And for poor D3, 0 out of 2.
00:04:11.820 | Recall is very similar but now the denominator changes, right?
00:04:15.140 | So the recall at 2 for this first one is 2 out of 3.
00:04:18.420 | That is, of the three starred documents,
00:04:20.520 | 2 are at k or above.
00:04:22.500 | Here it's 1 out of 3 and here at 0 out of 3.
00:04:26.500 | And just to round this out,
00:04:28.420 | poor D3 has not fared well in our ranking so far.
00:04:32.980 | But in a surprise twist,
00:04:35.040 | if I change the value of k to 5,
00:04:37.820 | all of a sudden D3 looks pretty good.
00:04:40.860 | Because now it's got all three of its stars at 5 or above.
00:04:45.180 | Whereas the other two,
00:04:46.860 | even though they've got some high stars up there,
00:04:49.560 | we're not sensitive to that precisely.
00:04:52.020 | And so now D3 has pulled ahead.
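Continuing the same sketch (again mine, not the lecture's), precision@k and recall@k over the same three rankings reproduce the numbers above, including the twist where D3 pulls ahead at k = 5:

```python
def precision_at_k(relevances, k):
    """Fraction of the top-k ranked documents that are relevant."""
    return sum(relevances[:k]) / k

def recall_at_k(relevances, k):
    """Fraction of all relevant documents that appear in the top k."""
    total_relevant = sum(relevances)
    return sum(relevances[:k]) / total_relevant if total_relevant else 0.0

d1, d2, d3 = [1, 1, 0, 0, 0, 1], [0, 1, 0, 0, 1, 1], [0, 0, 1, 1, 1, 0]

print([precision_at_k(d, 2) for d in (d1, d2, d3)])  # [1.0, 0.5, 0.0]
print([recall_at_k(d, 2) for d in (d1, d2, d3)])     # [2/3, 1/3, 0.0]
print([recall_at_k(d, 5) for d in (d1, d2, d3)])     # [2/3, 2/3, 1.0] -- D3 pulls ahead at k=5
```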
00:04:54.820 | And that is maybe something that you want to watch out for because people kind of
00:04:58.580 | innocently choose these k values when they're evaluating systems.
00:05:01.580 | And I just showed you that that could really impact the ranking of systems.
00:05:06.340 | And in particular, like,
00:05:09.180 | you know, it's hard to imagine since there are only six documents.
00:05:12.060 | But if it was a lot of work to travel down to our chosen k,
00:05:15.300 | if k was 1,000,
00:05:17.300 | this would obscure the fact that we might pick as our winner
00:05:20.960 | a system that had all the stars more or less at 1,000.
00:05:24.500 | And the other systems which have their stars at the top of this ranking,
00:05:28.020 | and therefore they're easy to find,
00:05:29.740 | those might be diminished with such a high k.
00:05:33.320 | And so that kind of gets you into the role of thinking,
00:05:35.740 | what are my users trying to do?
00:05:37.140 | What is the cost of them scanning down a list of ranked results and things like that?
00:05:41.020 | And that's where I want you to be when you think about these metrics.
00:05:44.660 | What are you trying to solve out there in the world?
00:05:47.740 | What are your users confronting?
00:05:49.200 | What is the cost of reviewing examples and so forth and so on? Yeah.
00:05:54.300 | Will the neural IR models that we're going to see kind of solve this problem, right,
00:05:59.340 | because right now everything's based on the presence or not of a word,
00:06:03.500 | rather than kind of maybe a- either a longer meaning or,
00:06:07.540 | um, like the quality of the relevance, however we define it.
00:06:11.140 | Like maybe it only says the word once but actually has the best information afterwards.
00:06:15.120 | Will that take care of that or is- is all of
00:06:17.420 | neural also going to be based on presence or not of words?
00:06:20.500 | That's a great question. Ah, wait, we should be careful.
00:06:23.060 | So yeah, I think for the first part of your question,
00:06:25.740 | I want to say the neural IR models are overall going to be better.
00:06:29.540 | Because of what you alluded to,
00:06:31.380 | they have a very rich semantic space.
00:06:34.140 | It won't directly impact this because these stars after all aren't about terms.
00:06:39.100 | This is about whether a whole document was relevant to a query.
00:06:41.980 | You should imagine that the background process is like some team of humans went through and said,
00:06:46.900 | okay, you searched for Bert and now I'm going through documents and saying,
00:06:51.380 | yeah, this one is relevant, this one isn't.
00:06:53.480 | That's what produced these rankings.
00:06:55.900 | But I think you're right in your core intuition.
00:06:58.500 | Term-based models are going to be kind of brittle.
00:07:01.260 | And if we have hard query document pairs, they might miss them.
00:07:06.020 | Actually, that reminds me like this,
00:07:08.020 | for some reason it didn't display before.
00:07:09.700 | Let's see if it displays now.
00:07:10.860 | I had this nice example that Omar created.
00:07:13.580 | This is an example of why search is a hard NLU problem.
00:07:17.880 | Because this is a query,
00:07:19.540 | what compounds protect the digestive system against viruses,
00:07:23.020 | where the response is certainly relevant,
00:07:25.540 | but there is zero relevant term overlap between query and document.
00:07:29.860 | All of the connections that we want to make are deeply semantic connections.
00:07:34.620 | And I do think that that is why neural IR models have pulled ahead for
00:07:40.100 | accuracy style assessments trying to be careful as you'll see.
00:07:45.300 | I have one more metric which is average precision.
00:07:50.860 | This will be, I think this is fair to say,
00:07:53.300 | our most nuanced metric.
00:07:55.120 | Okay. So a little bit hard to think about,
00:07:57.260 | but I think it's intuitive.
00:07:58.500 | Average precision, notice it has no K. And the reason it has no K is that we're going
00:08:04.540 | to sum over all the precision values for
00:08:07.200 | different Ks here where there is a relevant document.
00:08:11.140 | Think back to our rankings.
00:08:12.560 | Wherever there was a star,
00:08:14.200 | we're going to choose that as a K. And we're going to sum up
00:08:17.780 | just those precision values and divide it by the number of relevant documents.
00:08:23.800 | Here's an example. Same three rankings that we had before,
00:08:27.700 | and what I'll show you are these precision calculations.
00:08:30.620 | So for the first one,
00:08:32.040 | we have stars at position one, two, and six.
00:08:35.380 | And so we accumulate the precision values for one, two, and six.
00:08:39.540 | And those are, I hope,
00:08:40.940 | the ones I've given there.
00:08:42.260 | That sums to 2.5,
00:08:44.700 | and then we divide that by three,
00:08:46.540 | which is the number of relevant documents.
00:08:49.660 | So we've abstracted away the K,
00:08:52.900 | which is reassuring, and we're also checking at every level.
00:08:55.780 | So it's not going to have that sensitivity I showed you before where
00:08:59.500 | the choice of K dramatically impacts which rankings we
00:09:02.980 | favor because now we're kind of looking at all of the ones chosen by the ranking.
00:09:07.900 | So that's D1, and then for D2, same thing.
00:09:11.320 | But now we're checking at two, five,
00:09:12.980 | and six because that's where the stars are,
00:09:15.100 | and that sums to 1.4.
00:09:17.500 | And then for D3,
00:09:19.500 | we do the same thing at positions three, four,
00:09:21.780 | and five, and notice interestingly that D3 has pulled ahead of D2.
00:09:28.340 | That's less surprising to me in the current context because D2 is kind of good and kind of not.
00:09:37.860 | It has that one star that's near the top,
00:09:40.220 | but the other two stars are way at the bottom of our ranking,
00:09:43.500 | whereas at least D3 kind of put them all at least not literally at the bottom.
00:09:49.460 | Whereas D1 looks like just a slam dunk winner here.
00:09:53.300 | I mean, it simply has most of the stars right at the top.
00:09:56.940 | It has that one lonely one at the bottom,
00:09:59.300 | but you know on balance D1 looks good.
00:10:01.540 | So if I just stepped back from this little example,
00:10:04.620 | I would say that average precision is kind of nice in terms of giving what looks to me like
00:10:09.700 | a pretty nuanced picture of these three rankings.
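For completeness, a small sketch of average precision (my own code) over the same three rankings; it reproduces the 2.5/3 and 1.4/3 calculations above and shows D3 edging out D2:

```python
def average_precision(relevances):
    """Sum precision@k at every rank k that holds a relevant document,
    then divide by the total number of relevant documents."""
    total_relevant = sum(relevances)
    if total_relevant == 0:
        return 0.0
    ap = 0.0
    for k, rel in enumerate(relevances, start=1):
        if rel:
            ap += sum(relevances[:k]) / k   # precision at this relevant rank
    return ap / total_relevant

d1, d2, d3 = [1, 1, 0, 0, 0, 1], [0, 1, 0, 0, 1, 1], [0, 0, 1, 1, 1, 0]
print([round(average_precision(d), 3) for d in (d1, d2, d3)])
# [0.833, 0.467, 0.478] -- D1 is the clear winner, and D3 edges out D2
```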
00:10:13.500 | I think that's all of the accuracy style metrics.
00:10:19.260 | Of course, there are others that you'll encounter.
00:10:21.100 | Some are sensitive to the numerical like the float value,
00:10:24.620 | because sometimes you have not just a one or a zero,
00:10:26.660 | a star or not, but rather a float value for relevance.
00:10:29.860 | There are lots of versions of these that, of course, average over sets of queries.
00:10:34.060 | That'll be very common to see,
00:10:35.740 | but underlyingly that's just some kind of arithmetic average of these scores.
00:10:40.460 | So I think this is a good sample.
00:10:42.000 | Are there questions I can answer about these metrics?
00:10:46.220 | Really great.
00:10:49.100 | Yeah.
00:10:49.420 | The float value one for relevance,
00:10:52.020 | how's that at a high level computed?
00:10:54.980 | What's it called? Like the discounted cumulative gain,
00:10:59.220 | and it is the sum of all of the scores divided by something or other.
00:11:05.900 | This must be in your history somewhere.
00:11:08.340 | It is something.
00:11:09.220 | [LAUGHTER]
00:11:10.220 | You could also just label,
00:11:12.260 | you know, have human labels and then you can take the precision,
00:11:17.060 | or not precision, maybe just position-weighted combination of human labels.
00:11:21.660 | [inaudible]
00:11:24.780 | So you found the discounted cumulative gain.
00:11:27.420 | That's a metric I left out.
00:11:29.220 | And then you're just observing that very often for these datasets,
00:11:32.860 | we'd have humans do a bunch of labeling,
00:11:34.820 | and then average precision is one way of aggregating over the labels we might have collected.
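Since the formula was left vague in the exchange above, here is one standard definition of discounted cumulative gain and its normalized form for graded (float) relevance labels; this is my gloss, not something stated in the lecture:

```python
import math

def dcg_at_k(gains, k):
    """Discounted cumulative gain: graded relevance discounted by log2 of rank."""
    return sum(g / math.log2(rank + 1) for rank, g in enumerate(gains[:k], start=1))

def ndcg_at_k(gains, k):
    """DCG divided by the DCG of the ideal (relevance-sorted) ranking."""
    ideal = dcg_at_k(sorted(gains, reverse=True), k)
    return dcg_at_k(gains, k) / ideal if ideal > 0 else 0.0

# graded relevance labels for a ranking of four documents
print(round(ndcg_at_k([3.0, 0.0, 2.0, 1.0], k=4), 3))  # ~0.93
```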
00:11:41.180 | I kind of alluded to this before,
00:11:45.860 | but like here's a partial list of things you could think about.
00:11:48.780 | Which metric? Fundamentally, there is no single answer.
00:11:52.060 | Is the cost of scrolling through k passages low?
00:11:56.340 | Then maybe success at k is fine.
00:11:58.540 | Because you don't care whether it was a position nine or position one,
00:12:02.060 | what you really care about is that the user is kind of
00:12:04.660 | confronted with the success that they can easily find.
00:12:08.260 | That's one scenario that you could be in.
00:12:10.780 | Are there multiple relevant documents per query?
00:12:13.980 | This is straightforward. If so,
00:12:15.540 | you probably shouldn't use success at k or RR at k,
00:12:19.580 | because they're only sensitive really to one star.
00:12:22.620 | And if you went to the trouble of getting multiple stars,
00:12:25.660 | you know, why have your metric be insensitive to that?
00:12:28.580 | So that seems clear. Is it more important to find every relevant document?
00:12:33.500 | If so, you should favor recall.
00:12:35.100 | That would be a case where maybe human review is cheap,
00:12:38.460 | or the cost of missing an example is hugely expensive.
00:12:43.020 | In that case, you want to favor recall.
00:12:44.980 | You can't miss anything.
00:12:46.220 | Conversely, if you just need to find some relevant things,
00:12:50.620 | maybe in an ocean of examples,
00:12:52.700 | because you want to label them,
00:12:54.140 | or because it's just good to know about them,
00:12:56.300 | then you could favor precision.
00:12:57.820 | Because then all you really care about is that near the top are some relevant things.
00:13:03.740 | F1 at k is the harmonic mean of precision and recall.
00:13:07.860 | Same thing as we do in NLP.
00:13:10.340 | And that can be used where there are multiple relevant documents,
00:13:13.260 | but maybe the relative order above k doesn't matter.
00:13:16.900 | That's just one perspective on what I mean when I say we're combining precision and recall.
00:13:22.500 | And then finally, average precision of the ones I've showed you,
00:13:26.140 | will give you the most fine-grained distinctions of the metrics, right?
00:13:29.340 | Because it's sensitive to rank,
00:13:31.260 | and it's sensitive to precision and recall.
00:13:33.820 | Precision because it aggregates over those values,
00:13:36.060 | and recall because that's the denominator.
00:13:38.660 | So that looks like an awfully good way to really get a fine-grained ranking of systems.
00:13:45.020 | And then finally, I'm going to talk about this a bit later.
00:13:50.220 | We have to move on beyond accuracy.
00:13:52.260 | This is a paper that I did with a team recently of researchers here and at IBM.
00:13:56.780 | What we're seeing here is a kind of post-hoc leaderboard.
00:14:00.460 | Not an actual leaderboard because part of
00:14:02.860 | our complaint is that there are no leaderboards that
00:14:05.180 | really do anything beyond measuring accuracy style things.
00:14:08.660 | But if you did go through the literature as we did here and find a lot of systems,
00:14:13.900 | you can see that they vary widely along other dimensions.
00:14:17.660 | Here is the mean reciprocal rank,
00:14:19.740 | one of our rankings, goes from 19 to 37 or 39 or something.
00:14:24.740 | So you say, okay, but then look just to the right of that at the query latency.
00:14:29.540 | To get to 37,
00:14:31.660 | look how much time I have to spend versus down here where 36,
00:14:36.060 | I spend a fraction of the time.
00:14:38.100 | That is absolutely something that will matter to the search experience of users.
00:14:42.820 | There is almost no way they're waiting around for 691 milliseconds, for example.
00:14:49.060 | Or what about the index size?
00:14:50.660 | Right. If you care about space footprint and you will if you are indexing the web,
00:14:55.220 | some of these have tiny little indices and then, uh-oh,
00:14:58.980 | that's our model, Colbert V1, 154 gigabytes.
00:15:03.180 | Right. So now if you need to hold it in memory,
00:15:05.900 | your world just got a lot more expensive.
00:15:08.980 | You'll see over here RAM requirements.
00:15:11.660 | So BM25, it has no hardware requirements at all.
00:15:16.100 | You can run that on anything.
00:15:17.880 | Whereas these models down here that have these really high MRR scores,
00:15:22.220 | hugely expensive in terms of compute.
00:15:25.460 | Classic story of the neural age, right?
00:15:28.420 | So you have to pay somewhere.
00:15:30.700 | Then of course, I hope you're thinking about this.
00:15:34.260 | So then what is the best combination of all these things?
00:15:38.140 | Well, it depends on how much money you have and how much time you have,
00:15:41.180 | and how much you care about accuracy.
00:15:43.620 | And so the best pitch I can make to you is that as you evaluate systems,
00:15:47.880 | you think about what you care about,
00:15:50.140 | what matters, and construct your evaluations on that basis.
00:15:54.800 | That's gonna be a big theme of the course later on.
00:15:57.560 | And I'm hoping to time it so that you all,
00:16:00.520 | for your papers, are thinking about assessment.
00:16:03.120 | And you think, ah, you know,
00:16:04.480 | I should have a whole section about my philosophy of assessment here and not
00:16:07.720 | just fall into F1 or fall into success at K or whatever is relevant.
00:16:14.460 | This is kind of interesting too.
00:16:19.580 | This is from the same paper.
00:16:21.220 | Here's BM25.
00:16:23.500 | Costs essentially nothing,
00:16:25.460 | but it has very low performance.
00:16:27.300 | If you travel straight up from there,
00:16:29.580 | look at these SPLADE models.
00:16:31.660 | Also costing essentially nothing,
00:16:34.140 | but vastly better in terms of their performance.
00:16:36.940 | That looks like a real discovery to me.
00:16:39.160 | You know, this is like the Pareto frontier as they call it.
00:16:42.520 | These systems where you would just wouldn't
00:16:45.040 | choose any that are off the frontier no matter what your values.
00:16:48.520 | And obviously, you can see that to favor this model,
00:16:51.720 | there are gonna have to be other dimensions that we care about beyond cost and MRR,
00:16:56.040 | because otherwise, that's just not a choice you would make.
00:16:59.880 | But for all I know,
00:17:01.920 | there are hidden dimensions that need to be teased out that would
00:17:04.280 | show that that ANS model is the best relative to.
00:17:10.760 | Let's dive into some of those models then.
00:17:15.280 | Neural IR. First, we'll start with cross-encoders.
00:17:20.840 | This will be very intuitive.
00:17:22.760 | Okay. Here, just imagine I have a huge transformer.
00:17:26.880 | And for cross-encoders, what I do is,
00:17:29.640 | I just concatenate the query and the document together,
00:17:32.780 | process them with my transformer model.
00:17:35.440 | And then on the top here,
00:17:36.800 | I put a little scoring function.
00:17:38.440 | And the scoring function will just say,
00:17:40.600 | for this query, how good is this document?
00:17:44.000 | Enormously powerful to this comment from before,
00:17:48.800 | we are making like maximal use of this,
00:17:52.320 | say, BERT model here to get
00:17:54.600 | every possible interaction between query and document.
00:17:57.880 | So this will be good in terms of accuracy.
00:18:00.600 | But you might worry about some other things.
00:18:03.140 | Here, let me walk through a bit more.
00:18:04.620 | In the background here,
00:18:05.740 | I'm assuming that our dataset looks like this.
00:18:08.220 | We have a query, one positive document,
00:18:11.660 | and a set of one or more negative documents.
00:18:15.500 | We could have multiple of the negatives.
00:18:18.740 | What I'm depicting on the left here is a model we could summarize like this.
00:18:24.040 | This is the encoder.
00:18:25.980 | We concatenate the query and the document,
00:18:28.980 | and process them, and we retrieve this representation here,
00:18:33.340 | layer n, position 0.
00:18:36.020 | We feed that through a dense layer that does our scoring,
00:18:39.700 | and that is the basis for essentially a classifier.
00:18:44.060 | This is called the negative log likelihood of the positive passage.
00:18:47.860 | And if you squint or you don't squint,
00:18:50.540 | you just let it go blurry,
00:18:51.940 | you will see that it is a typical classifier loss.
00:18:55.420 | The only possible twist is that the numerator is
00:18:58.620 | the positive passage score and then in the denominator,
00:19:02.020 | I have the positive passage score summed together with
00:19:04.380 | all the negative passage scores that I have in my example set.
00:19:07.940 | But fundamentally, it's a classifier.
00:19:11.340 | So that's why the examples look like this, because that's what's being used
00:19:16.980 | here to optimize all these parameters to score documents.
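Here is a minimal PyTorch sketch of that setup, under the assumption that `encoder` is some transformer that maps token ids and a mask to final-layer states of shape (batch, seq, hidden); the class and function names are mine, not from any library:

```python
import torch
import torch.nn as nn

class CrossEncoderScorer(nn.Module):
    """Encode the concatenated query+document jointly and project the
    final-layer position-0 ([CLS]) state down to a single relevance score."""
    def __init__(self, encoder, hidden_dim):
        super().__init__()
        self.encoder = encoder              # assumed: returns (batch, seq, hidden)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids, attention_mask)
        return self.score(hidden[:, 0, :]).squeeze(-1)   # layer n, position 0

def nll_of_positive_passage(pos_score, neg_scores):
    """-log( exp(s+) / (exp(s+) + sum_j exp(s-_j)) ): softmax cross-entropy
    with the positive passage as the target class. pos_score is a single score."""
    scores = torch.cat([pos_score.view(1), neg_scores.view(-1)])
    return -torch.log_softmax(scores, dim=0)[0]
```

The same loss is reused for the other architectures below; only the scoring ("comp") function changes.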
00:19:22.780 | Final thing, I hope you're thinking about this.
00:19:26.740 | It's going to be incredibly expressive and powerful,
00:19:29.380 | but it just won't scale.
00:19:31.740 | The cost of having the query and document interact at
00:19:34.940 | query time is that I can't process any of these documents ahead of time.
00:19:38.860 | So just imagine this,
00:19:40.140 | your query comes in on the web,
00:19:41.860 | like you're Google, you're using a cross-encoder, and a user queries.
00:19:45.820 | You need to process that query together with every single document on the web,
00:19:50.220 | to score them, and then on that basis,
00:19:52.660 | you will get beautiful scores.
00:19:54.660 | But obviously, each query could take years to serve.
00:19:59.380 | So from this perspective,
00:20:01.700 | it is just not a practical choice.
00:20:03.940 | Maybe we could use it for re-ranking.
00:20:06.620 | You see this sometimes where a cheap retriever gets,
00:20:10.220 | like, a 1,000 documents and then this is used to re-rank just those 1,000.
00:20:13.980 | But we can't do this at web scale.
00:20:16.700 | So a question in the back. Yeah.
00:20:18.540 | Um, could you use this with multiple possible,
00:20:21.500 | uh, positive documents as well?
00:20:25.540 | Like if you were like, like for example,
00:20:27.820 | for like the ranking thing right here,
00:20:29.340 | like multiple of those could be like good, but.
00:20:33.700 | I- let's see.
00:20:36.180 | I don't see why not.
00:20:37.100 | The numerator could be the sum of the positive and
00:20:39.340 | then the denominator could just include all of those.
00:20:41.780 | So what you would be doing is, well,
00:20:46.220 | I'm just trying to think through.
00:20:51.900 | That, that would be one approach.
00:20:53.060 | The other approach would be to just treat them as separate examples.
00:20:56.700 | I think under some conditions,
00:20:58.620 | those will be identical,
00:20:59.700 | but I'd have to think it through.
00:21:00.780 | But I don't see a problem.
00:21:04.140 | I don't see a problem.
00:21:06.220 | But it's worth thinking about. I'll get back to you on that.
00:21:10.140 | Let's improve on this.
00:21:13.660 | DPR, dense passage retriever.
00:21:17.020 | This will also be intuitive.
00:21:18.260 | Here we go. Query and document,
00:21:20.020 | except notice now,
00:21:21.580 | they are processed by separate models.
00:21:23.900 | The query encoder and the document encoder.
00:21:26.340 | Could be the same parameters,
00:21:27.700 | but the point is, we process them separately.
00:21:30.280 | And I've made lighter every state except the output representations
00:21:34.980 | above the two [CLS] tokens,
00:21:37.120 | because those are the only ones that we need.
00:21:39.660 | Okay. These two.
00:21:41.700 | And then we do some kind of scoring on that basis like similarity.
00:21:45.740 | Right. So here are examples are the same.
00:21:48.660 | Now, the similarity function as I'm calling it for,
00:21:52.140 | for a query and a document is we process the query,
00:21:54.980 | we get this guy, process the document,
00:21:57.540 | and we get this guy,
00:21:58.780 | and then we do scoring on that basis.
00:22:00.900 | There are no additional parameters.
00:22:03.100 | We just score based on those representations.
00:22:06.420 | They're dot product.
00:22:08.260 | So now, we've got something that is highly scalable because we can process
00:22:13.740 | every document in our entire web collection into a single vector,
00:22:17.660 | this one, and it can just sit there on disk.
00:22:20.540 | And then at query time,
00:22:22.100 | process the query, get its vector,
00:22:24.000 | and do this super fast dot product comparison for scoring.
00:22:28.620 | So now we've got something that is probably even going to
00:22:31.380 | function as a full ranking model, not just a re-ranker.
00:22:35.380 | But the real game is that we can process all our documents offline.
00:22:39.140 | The cost was that we now have
00:22:41.500 | almost no interactions between the query and the document.
00:22:44.460 | Like if you think about token identities,
00:22:46.920 | if you think about the soft matching that happens with TF-IDF,
00:22:50.940 | none of that is going to be able to happen here.
00:22:53.900 | That was the cost. Yeah.
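A minimal sketch of that bi-encoder setup (my own naming, with the same assumed encoder signature as before); the point is that document vectors can be computed offline and scoring at query time is just a dot product:

```python
import torch.nn as nn

class BiEncoder(nn.Module):
    """DPR-style dual encoder: query and document are encoded separately,
    and relevance is the dot product of the two [CLS]-position vectors."""
    def __init__(self, query_encoder, doc_encoder):
        super().__init__()
        self.query_encoder = query_encoder   # assumed: returns (batch, seq, hidden)
        self.doc_encoder = doc_encoder       # must share the output dimensionality

    def encode_queries(self, ids, mask):
        return self.query_encoder(ids, mask)[:, 0, :]    # (batch, hidden), at query time

    def encode_docs(self, ids, mask):
        return self.doc_encoder(ids, mask)[:, 0, :]      # can be precomputed offline

    @staticmethod
    def scores(query_vecs, doc_vecs):
        return query_vecs @ doc_vecs.T                   # (num_queries, num_docs) dot products
```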
00:22:59.900 | So, uh, I mean,
00:23:01.540 | essentially what we'll be comparing would be two fixed length vectors,
00:23:05.020 | right? I mean, a vector representing the document and another representing the query.
00:23:08.860 | So is there, I mean,
00:23:11.260 | a limit to the length of the document that would be represented in that vector?
00:23:15.580 | I mean, like, could it represent an arbitrary,
00:23:18.260 | a long document? It would lose context.
00:23:22.420 | These are great questions. Let me repeat them.
00:23:24.060 | For the first question, yes, you are right.
00:23:26.700 | The one constraint we need to impose on the query encoder and
00:23:29.900 | the document encoder is that they have
00:23:31.420 | the same dimensionality so that we can do the dot product.
00:23:34.160 | They can otherwise be separate models if we want.
00:23:37.040 | And then length of query and length of document,
00:23:39.840 | that's just going to be imposed by whatever we
00:23:42.020 | choose for the query and the document themselves.
00:23:44.240 | So if you choose BERT,
00:23:45.860 | you're going to be stuck with 512 as the length,
00:23:47.980 | the longest document that you can process unless we
00:23:50.620 | do some further manipulation of these things. Yeah.
00:23:56.700 | If these models are trained to project into kind of like a shared embedding space,
00:24:03.060 | so like documents that are similar to a query are going to
00:24:06.040 | fall into a similar location in embedding space,
00:24:09.700 | could we have a system where we essentially take,
00:24:12.840 | like we pre-process all the documents,
00:24:15.240 | we take a query at inference time,
00:24:17.840 | project it into embedding space and then do like a nearest-neighbor search or something like that?
00:24:24.200 | Well, yes. So some aspects of what you're describing,
00:24:27.200 | I think, are what DPR will be optimized for.
00:24:30.000 | The other parts of what you're saying are going to be
00:24:32.440 | optimization tricks that I show you in a second, I believe. Yes.
00:24:37.960 | You elaborate on what you mean by limited query doc interactions?
00:24:41.880 | Just that all we've got in the end is this vector for
00:24:45.860 | the whole query and this vector for the whole document.
00:24:49.020 | So token identities to the extent that they're preserved at all,
00:24:52.400 | they have to have been packed into those vectors.
00:24:55.900 | Whereas over here, we had
00:25:00.040 | every token level interaction you can imagine as a result of us using like the transformer.
00:25:06.000 | Yeah.
00:25:09.160 | Is there any room for training some more clever synthesis of
00:25:13.120 | the two representations you get at the end as opposed to just dot producting them?
00:25:16.480 | Oh, that's, yeah.
00:25:17.760 | I think that's a natural follow-on is that you might think,
00:25:20.660 | I want to have in this layer here some additional parameters,
00:25:24.260 | and you can kind of see how that might work, right?
00:25:26.560 | So instead of just using this vector,
00:25:28.260 | I would put some parameters on top and then the same optimization can be used.
00:25:33.560 | Yeah.
00:25:34.160 | If they're going to the same embedding space, attention.
00:25:39.040 | Yeah.
00:25:40.080 | Yeah. Yeah. And this could be good in the sense that we,
00:25:44.280 | we would pay a little bit of a cost by adding more parameters,
00:25:47.440 | but we might gain something in terms of expressivity.
00:25:51.320 | Nice. Let me show you a happy compromise.
00:25:58.420 | Oh yeah, I just wanted to point this out,
00:26:00.900 | that I've just showed you two loss functions.
00:26:03.500 | I showed you the cross encoder and the DPR,
00:26:06.000 | and you can probably already see that they are identical
00:26:08.800 | except for this function that you might call comp here.
00:26:12.340 | And that's kind of freeing.
00:26:14.100 | And as you think about these different model architectures,
00:26:16.980 | probably what you're thinking about is simply changing
00:26:19.860 | this comp function and then using your available data to
00:26:23.200 | train the model against this negative log likelihood of the positive passage.
00:26:27.720 | There are other losses out there in the literature,
00:26:31.160 | but this is the most widely used and it seems to be very effective.
00:26:35.760 | Colbert. This stands for contextualized late interaction with BERT.
00:26:44.120 | It was invented by Omar and Matei who are here.
00:26:46.760 | Omar is my student, I work closely with Matei.
00:26:49.040 | Um, and let's see.
00:26:51.040 | So Omar would want you to know that this stands for contextualized late interaction,
00:26:56.120 | and he pronounces it 'Colbert,' like Stephen Colbert, because Stephen Colbert has
00:26:59.520 | a show full of contextualized late night interactions.
00:27:04.120 | But you can also pronounce it 'Col-BERT.'
00:27:07.340 | It's your choice because the BERT there is the BERT model.
00:27:11.680 | And we are, yes, still hoping that Stephen Colbert will take notice of this.
00:27:16.040 | [inaudible]
00:27:25.200 | I haven't been so bold,
00:27:26.440 | but I welcome you all to do that.
00:27:28.760 | That's great. Add him on Twitter, yes.
00:27:31.640 | Here's how this will work. I've drawn
00:27:33.920 | the query encoder on the side for reasons that you'll see,
00:27:36.320 | but it's the same kind of thing.
00:27:37.520 | So imagine BERT processes my query and I've grayed out everything,
00:27:41.000 | but the final representations because crucially,
00:27:43.280 | those are the only ones that we actually need.
00:27:45.480 | Same thing with the document,
00:27:47.400 | and it could be the same encoder.
00:27:49.300 | Now what I'm going to do with Colbert is form a grid of scores.
00:27:54.040 | And this is going to essentially give the similarity value
00:27:57.640 | between every query token and every document token.
00:28:02.560 | And then I will choose the values along the rows,
00:28:07.400 | that is, for each query token,
00:28:08.480 | the document token that maximizes that similarity comparison.
00:28:12.880 | And the scoring function is essentially the sum of those three max values.
00:28:18.800 | That is why you see MaxSim all over the place for Colbert.
00:28:23.760 | Examples are as before,
00:28:26.320 | losses as before,
00:28:28.240 | and I wrote down here what we would think of as the comp function,
00:28:31.200 | and I wrote it as MaxSim.
00:28:32.760 | For a query and a document,
00:28:34.200 | you sum over all the query tokens,
00:28:36.640 | and you get the max matching document token.
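A sketch of that comp function (my code; the released implementation also normalizes the vectors and masks padding and punctuation, which I omit here):

```python
import torch

def maxsim(query_vecs, doc_vecs):
    """Late interaction: for each query token vector, take the maximum
    similarity over all document token vectors, then sum over query tokens.
    query_vecs: (num_query_tokens, dim); doc_vecs: (num_doc_tokens, dim)."""
    sim_grid = query_vecs @ doc_vecs.T        # token-by-token similarity grid
    return sim_grid.max(dim=1).values.sum()   # max over doc tokens, sum over query tokens

# toy example with random unit-normalized token vectors
q = torch.nn.functional.normalize(torch.randn(5, 128), dim=-1)
d = torch.nn.functional.normalize(torch.randn(40, 128), dim=-1)
print(maxsim(q, d))
```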
00:28:40.760 | So you can see why it's contextualized late interaction,
00:28:44.120 | because I'm using the output states.
00:28:46.520 | But unlike DPR, I'm allowing them all to interact with
00:28:50.360 | each other via these very fast maxim calculations.
00:28:54.120 | I have token level interactions.
00:28:56.760 | It's right, so highly scalable and highly expressive.
00:29:00.440 | The only cost is that the interactions happen only in this very thin final layer.
00:29:06.320 | But this is really pleasing for IR,
00:29:10.120 | that Colbert because of this brings us back to common intuitions in IR.
00:29:14.960 | We do genuinely achieve with these maxim scores intuitive soft alignments.
00:29:20.120 | Here I have the query, when did the Transformers cartoon series come out?
00:29:24.000 | And the response document,
00:29:25.680 | the animated Transformers was released in August 1986.
00:29:29.480 | And these are proportional to actual maxim values query relative to document.
00:29:35.400 | The thickness of that line.
00:29:37.160 | And you can see that it is doing something very intuitive,
00:29:40.080 | and also something very semantic.
00:29:42.440 | Because unlike term-based models,
00:29:44.240 | I don't have to do anything special to capture the fact that
00:29:47.600 | come in the context of come out is a lot like released.
00:29:52.400 | And similarly with when and that date.
00:29:55.000 | Here I'm showing the two topmost MaxSim values,
00:29:58.060 | and they're also very intuitive.
00:30:00.040 | And this is wonderful because IR has been so successful for so long,
00:30:04.200 | doing term matching, and it is nice to see that intuition
00:30:07.920 | carried forward into this more semantic space in my view.
00:30:12.520 | So I can go up. Yeah.
00:30:14.880 | Is that matrix of like query to document mapping,
00:30:19.760 | is that why the in-memory index is very purple there?
00:30:23.360 | Yes. So your question is,
00:30:25.520 | why is the index for Colbert so big?
00:30:27.840 | It is because we have to store every token level representation.
00:30:32.080 | Yes. I'm gonna, I'm gonna show you that we can do better,
00:30:34.360 | but naively storing these for
00:30:36.960 | our entire document store is gonna be a lot of vectors.
00:30:40.600 | One per token, not one per type, one per token.
00:30:46.040 | Yeah. I have to pay somewhere.
00:30:50.520 | I guess that's the insight. Yeah. Question.
00:30:53.240 | Maybe there's some intuition that can make this a little clearer.
00:30:57.280 | If the document has multiple variants of,
00:31:00.320 | of transformers or, you know,
00:31:02.200 | Decepticon, Optimus Prime or whatever,
00:31:04.360 | they, they're all related to the original token transformers.
00:31:07.840 | Would you be able to kind of draw that link to all those other tokens as well,
00:31:12.040 | or would you have to pick,
00:31:14.200 | I guess, one relationship?
00:31:17.080 | Transformers is represented once.
00:31:19.520 | I think that's a great question.
00:31:21.960 | So transformers, that's why I picked it.
00:31:23.960 | It's amusingly ambiguous between our model and the animated cartoon series.
00:31:29.840 | Um, of my youth.
00:31:32.240 | So, but I think the point is that because it's BERT,
00:31:36.760 | transformers in this context will have
00:31:39.320 | a very different representation from the one if we were talking about NLP.
00:31:43.200 | And that's why it's so good that we are using BERT because then we'll
00:31:46.280 | get Maxims that are appropriately semantic. That is the hope.
00:31:50.600 | Yeah. Whereas term-based models really gonna struggle.
00:31:53.560 | The best they're gonna be able to do is have
00:31:55.840 | n-grams that kind of capture the fact that this transformers is the cartoon series one.
00:32:01.840 | So another, actually, that's another argument in favor of being in a more semantic space.
00:32:08.760 | I want to just quickly talk with you about how we have worked to optimize Colbert,
00:32:15.880 | because I think that this suggests things that you would want to do if you
00:32:18.600 | developed your own neural retrieval models because the hard truth here is that
00:32:23.080 | BM25 is blazingly fast and scalable and these neural models are not.
00:32:29.800 | You have to work much harder to get them to that point of being as performant
00:32:33.240 | in terms of other dimensions beyond accuracy.
00:32:36.660 | We could use Colbert as a re-ranker as I alluded to before, right?
00:32:40.960 | So here I have all these token level representations which I do have to store,
00:32:46.040 | and they're each connected to a document.
00:32:48.880 | Now, on- if used naively,
00:32:51.400 | this will be not scalable,
00:32:53.080 | but I could do this.
00:32:54.120 | Given some query, uh,
00:32:55.920 | that's represented as a sequence of tokens, uh,
00:32:59.080 | I could get the top k documents for it using like BM25 and then re-rank that top k.
00:33:05.440 | [NOISE]
00:33:06.760 | And so if k is small,
00:33:08.560 | I pay the full price of the Colbert model but only for k documents.
00:33:12.840 | And you're hoping that BM25 did a good job of getting you to that initial point.
00:33:18.240 | It's a very common application and it can be really meaningful to
00:33:21.960 | re-rank that final set of k documents.
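As a sketch, that two-stage pipeline might look like the following; `bm25_top_k`, `encode_query`, and `encode_doc` are hypothetical helpers passed in as arguments, not real library calls:

```python
def rerank_with_late_interaction(query, bm25_top_k, encode_query, encode_doc,
                                 score_fn, k=1000):
    """Cheap first-stage retrieval proposes k candidates; the expensive
    neural scorer (e.g. the maxsim function above) re-scores only those k."""
    candidates = bm25_top_k(query, k)              # e.g. BM25 document ids
    q_vecs = encode_query(query)                   # query token vectors
    scored = [(doc_id, score_fn(q_vecs, encode_doc(doc_id))) for doc_id in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
```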
00:33:25.280 | But we could do a little better.
00:33:26.960 | If we wanted to use Colbert end-to-end,
00:33:28.760 | here's how we could work.
00:33:29.960 | We again store all those token level vectors,
00:33:33.020 | but now we're gonna kind of turn things around.
00:33:35.160 | We just need to keep track of those vectors and their associated documents.
00:33:39.820 | For a query that we have encoded as a set of vectors using Colbert,
00:33:44.840 | we take each query vector wi,
00:33:47.940 | and we retrieve the p most token vectors from
00:33:50.960 | this huge list that are similar to our target.
00:33:54.360 | And that doesn't require the full Colbert model.
00:33:56.840 | That could just be a similarity calculation and you can do those really fast.
00:34:01.400 | People have really optimized that.
00:34:03.200 | And then you get all the documents associated
00:34:06.120 | with this small set of vectors that you found and you score them.
00:34:09.640 | So again, the name of the game is to use Colbert only very
00:34:13.680 | sparingly in a final stage because it is so expensive.
00:34:18.780 | And then a third step here,
00:34:21.240 | just quickly, we can do even better and this is quite striking.
00:34:24.860 | What we can do is cluster our token level representations
00:34:28.940 | into their centroids using k-means clustering.
00:34:31.980 | That's what I've got in red here.
00:34:34.260 | And then use them as the basis for search.
00:34:36.860 | So again, we encode our query into a series of vectors.
00:34:40.300 | And then for this target vector wi,
00:34:43.060 | we first get the centroids that are closest to that.
00:34:46.340 | And this is important because in practice,
00:34:48.620 | we can collect only like four centroids per token vector and do really well.
00:34:54.740 | That's a tiny number.
00:34:56.420 | Then we get the t most similar token vectors to that centroid.
00:35:01.580 | And then we finally do scoring on the associated documents.
00:35:05.980 | And so by leaps and bounds here,
00:35:08.980 | we have reduced the amount of compute we need to do with this huge index
00:35:13.300 | by using the centroids and then using Colbert again very sparingly.
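Here is a rough sketch of that centroid-based candidate generation (my simplification: it gathers every token id assigned to the probed centroids rather than only the t most similar ones); `centroid_to_token_ids` is an assumed inverted mapping built at indexing time:

```python
import torch

def candidate_tokens_via_centroids(query_vecs, centroids, centroid_to_token_ids, n_probe=4):
    """For each query token vector, find its n_probe nearest centroids, then
    gather the document-token ids clustered under those centroids. Only the
    documents behind this small candidate set get scored with full MaxSim."""
    sims = query_vecs @ centroids.T                  # (num_query_tokens, num_centroids)
    nearest = sims.topk(n_probe, dim=1).indices      # a handful of centroids per query token
    candidates = set()
    for row in nearest.tolist():
        for centroid_id in row:
            candidates.update(centroid_to_token_ids[centroid_id])
    return candidates
```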
00:35:18.660 | Final thing and then I'll take some questions.
00:35:22.020 | The team has worked very hard to reduce the latency of Colbert.
00:35:25.940 | This is a latency analysis here.
00:35:27.900 | And the thing I want to point out to you is that the Colbert model steps,
00:35:30.780 | actually for this second version,
00:35:32.640 | the one I just described to you with the centroids,
00:35:34.900 | that was actually a relatively small part of
00:35:37.260 | the overall cost because it was being used so sparingly.
00:35:40.380 | The big costs were in dealing with
00:35:44.460 | the huge index and also in the work to quantize the vectors,
00:35:49.940 | so that they were easier to store on disk by making them smaller.
00:35:53.520 | And so after a bunch of work with this framework called Plaid,
00:35:57.580 | they were able to get rid of almost all of that index lookup and
00:36:01.380 | de-quantization or decompression steps for the vectors that were costing so much.
00:36:05.900 | And they brought the latency down to like 58 milliseconds.
00:36:10.620 | Which- so it went from something that is impossible to imagine deploying
00:36:14.980 | industrially to something that is close to
00:36:17.820 | what you might entertain as a possibility for deployment.
00:36:21.100 | And I- you know, the details are in the Plaid paper.
00:36:23.900 | We can talk about them offline.
00:36:25.500 | I just wanted to call out that I think this is an incredible achievement.
00:36:29.100 | It is so clever,
00:36:30.980 | the set of things that they did to achieve this enormous improvement.
00:36:35.300 | So shout out to them.
00:36:36.580 | And it does mean that if you had heard a rumor that Colbert was
00:36:40.580 | impractical to use because the index was too large and the latency was too long,
00:36:44.620 | I think it's not true anymore.
00:36:46.500 | The indices are small because of quantization,
00:36:49.180 | and this is that picture of latency.
00:36:51.300 | So give it a shot.
00:36:54.100 | I have one more model, but let me take questions.
00:36:57.020 | Yeah. Did you have a question?
00:36:58.380 | Oh, sorry. I just had a question about the latency and also the predictor.
00:37:01.700 | Okay. Cool.
00:37:02.740 | Yeah, I'm happy to talk more.
00:37:03.860 | The Plaid paper is full of tricks and things like that.
00:37:07.020 | I don't want to take up too much time.
00:37:08.580 | I definitely want to give Sid plenty of time to talk about models.
00:37:10.900 | So let me just show you one more.
00:37:12.340 | This is SPLADE. This is also ingenious.
00:37:14.540 | It'll get you thinking in a new way.
00:37:16.420 | Okay. So for SPLADE,
00:37:18.740 | I wrote sequence at the bottom because we're going to do this for
00:37:21.780 | both queries and documents, this process.
00:37:24.780 | And crucially, here I have the vocabulary.
00:37:28.620 | I've only represented seven tokens,
00:37:30.560 | but if it was BERT, it would be like 30,000.
00:37:32.980 | Okay. So we again process the text into the output states,
00:37:38.540 | T1 through T3 there.
00:37:40.980 | And then we form all these scores.
00:37:44.020 | And the scores are determined by this thing here, that's S_ij.
00:37:48.420 | So we're going to apply a linear layer to the encoding,
00:37:51.900 | those output states, and we're going to combine it with
00:37:55.500 | the embedding for these vocabulary items with a bias.
00:37:59.300 | So if you strip away the details,
00:38:01.060 | you can see that this is like a dot product of
00:38:03.540 | these states with all of these values here in our vocabulary.
00:38:08.660 | And then SPLADE is the sum of that.
00:38:12.060 | And so you can think of that as summing across all the document tokens.
00:38:15.260 | And so what we've got in that orange column there is a probably very sparse vector
00:38:22.140 | that represents this text down here with respect to our vocabulary.
00:38:29.260 | So this is a lot like term-based, uh, work of old, right?
00:38:38.500 | This is a lot like a TF-IDF representation,
00:38:41.320 | except it was done in the neural space.
00:38:43.420 | So we should get the advantages of being with a semantic model.
00:38:48.180 | And then the similarity value is just the SPLADE representation,
00:38:52.580 | that is this representation here for the query dot product with the document.
00:38:57.820 | And the loss is the one that we've been using all along.
00:39:02.100 | So just to be clear, so you do the SPLADE process both with the query and with the document.
00:39:11.140 | And then, okay, cool.
00:39:12.940 | Yeah, that's it. There's a bunch of- it looks similar in my document.
00:39:15.980 | This is great. Let me review what you just said.
00:39:17.660 | There's a bunch of new things.
00:39:18.860 | Sequence, not query or document because we do this for both kinds.
00:39:22.900 | And of course, we can do all the documents ahead of time.
00:39:26.180 | The big twist is that we're scoring these sequences with respect to the vocabulary.
00:39:31.140 | And we are essentially getting in semantic space because this is an embedding space here,
00:39:36.020 | and this is a contextual embedding space here,
00:39:38.500 | scores for each query term with respect to the whole vocabulary.
00:39:44.180 | That gives us this big,
00:39:45.940 | presumably pretty sparse vector,
00:39:47.980 | and their optimization further encourages sparsity.
00:39:51.300 | And then the similarity value is the dot product of those for queries and for documents.
00:39:56.580 | So it has some hallmarks of late interaction,
00:39:59.860 | except it is interacting the text representations with the vocabulary,
00:40:05.220 | kind of like what you get with TF-IDF.
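A minimal sketch of that scoring head, following the description here (an MLM-style projection onto the vocabulary, summed over the sequence); the released SPLADE models additionally apply a ReLU/log saturation and a sparsity regularizer, which I leave out, and the names are mine:

```python
import torch
import torch.nn as nn

class SpladeLikeHead(nn.Module):
    """Project each output state onto the whole vocabulary, then aggregate over
    the sequence so every text becomes one |V|-dimensional (ideally sparse) vector."""
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.transform = nn.Linear(hidden_dim, hidden_dim)
        self.vocab_embeddings = nn.Parameter(torch.randn(vocab_size, hidden_dim) * 0.02)
        self.vocab_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states):                    # (batch, seq, hidden)
        s = self.transform(hidden_states) @ self.vocab_embeddings.T + self.vocab_bias
        return s.sum(dim=1)                              # (batch, vocab): one vector per text

def splade_similarity(query_vec, doc_vec):
    """Relevance is just the dot product of the two vocabulary-sized vectors."""
    return (query_vec * doc_vec).sum(-1)
```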
00:40:08.300 | And this model is outstanding.
00:40:12.020 | You saw it in some of my doc- of my slides before, very impressive.
00:40:16.860 | And it's also a new way of thinking, which I really like.
00:40:21.820 | Here's a bunch of more recent developments,
00:40:24.780 | and one theme of them, I won't go through them,
00:40:26.620 | is just that people are working hard finally on making these models more efficient.
00:40:31.780 | So a big theme of this is not just obsession with accuracy,
00:40:35.420 | but also obsession with especially latency.
00:40:39.620 | And then finally, for that paper that I mentioned before,
00:40:43.300 | we did a bunch of systematic investigations of different approaches.
00:40:48.140 | You can see BM25, DPR, Colbert,
00:40:50.780 | and some splayed models here.
00:40:52.540 | And these are all kind of variants of
00:40:54.580 | these models where people have worked hard to optimize them.
00:40:57.580 | There's lots of tables like this in the paper.
00:41:00.060 | Let me just draw out a few comparisons.
00:41:02.080 | BM25 is the only solution that could run on this tiny hardware here.
00:41:07.720 | We couldn't even run the other systems.
00:41:09.900 | That's why it's alone in its own little block there.
00:41:12.580 | And it costs nothing.
00:41:15.260 | Right. But it's not that successful either.
00:41:18.300 | Success at 10 is low relative to the rest.
00:41:21.060 | When we move here,
00:41:22.700 | this is sort of interesting.
00:41:24.260 | These two Colbert models,
00:41:26.500 | uh, achieve very similar performance.
00:41:29.260 | If you look all the way to the right,
00:41:30.900 | except one of them is double the latency of the other one for this hardware.
00:41:36.900 | And so you might wonder, do I really need this extra point of performance?
00:41:42.300 | If I'm gonna have to wait that long.
00:41:44.620 | And then if you look at SPLADE,
00:41:46.540 | so SPLADE is below Colbert v2 small,
00:41:50.180 | but its latency is a quarter or something of the Colbert v2 small.
00:41:57.780 | So maybe you care more about that and not so much about the success.
00:42:01.380 | And then if you compare these two SPLADE models, right,
00:42:03.780 | they have the same performance.
00:42:05.540 | But if you just jack up the hardware a little bit,
00:42:09.580 | then you get much lower latency.
00:42:13.140 | But look how much the price went up.
00:42:16.660 | It went up for all of them with this heavy duty hardware.
00:42:20.420 | Uh, yeah.
00:42:22.500 | So this is the space that you're actually operating in.
00:42:25.020 | I'll- we'll talk later about how we might more systematically integrate all these scores.
00:42:29.940 | I think this is enough now to get you thinking about all of these dimensions.
00:42:34.180 | [inaudible]
00:42:35.980 | They are the Plaid Colbert.
00:42:37.140 | Yeah. Pretty expensive there.
00:42:38.900 | Luckily, in the paper,
00:42:40.020 | we show that you never need a GPU for Colbert, I believe.
00:42:42.900 | You just- so you can always use cheaper hardware.
00:42:45.860 | Yeah. But those costs do look scary.
00:42:49.300 | [LAUGHTER]
00:42:52.180 | The final section of this is just some datasets.
00:42:54.220 | I think I don't need to go through it because you have it as a resource.
00:42:56.660 | If you want to get started in neural information retrieval,
00:42:59.500 | you've got TREC, MS MARCO,
00:43:02.140 | and then there are a bunch of new benchmarks that are
00:43:04.420 | designed to assess systems out of the box, that is zero-shot.
00:43:07.540 | BEIR is great. LoTTE is great for long-tailed,
00:43:11.260 | topic-stratified evaluation, and then this XOR-TyDi is cool because it is multilingual.
00:43:17.140 | And I know you have expressed interest in multilingual stuff.
00:43:20.060 | This could be a great playground for doing that with kind of QA and retrieval,
00:43:24.260 | like open QA as we've been doing it.
00:43:26.860 | Bunch of other topics.
00:43:29.140 | I think the bottom line here is just,
00:43:32.820 | again, this is like a refrain in this class.
00:43:35.620 | NLU and IR are back together again after being apart for so long,
00:43:40.340 | and this is having profound implications for research and technology development.
00:43:44.860 | So this is absolutely a very exciting moment to participate in this research because there is
00:43:50.420 | so much innovation yet to happen and it is having
00:43:54.060 | such an impact on research and also out in the wider world.
00:43:59.100 | Excellent. All right.
00:44:01.540 | Sid, want to take over?
00:44:03.260 | So it's cool. It's like retrieval isn't just hitting NLU,
00:44:06.380 | it's hitting everywhere, like vision and robotics. As of, like,
00:44:10.260 | this week, we're starting to use retrieval methods to do this:
00:44:13.660 | What's the best way to figure out how to do a new task?
00:44:16.820 | Maybe retrieve some examples of a robot or a human doing
00:44:20.460 | the same task and then generate your actions.
00:44:23.060 | So cool stuff. Cool. All right.
00:44:27.860 | Let's see if this works.
00:44:30.580 | Yeah. All right. So I'm going to kind of pick up or try to pick up where I left off
00:44:39.060 | last week and kind of give you this evolution,
00:44:41.820 | this history lesson on how we got to the transformer,
00:44:45.420 | and then go from there into tips and tricks for training big models generally,
00:44:50.620 | and then end with like a small little teaser on
00:44:53.020 | fine tuning and parameter efficient tuning.
00:44:56.060 | So you can use that in your projects down the road.
00:44:58.780 | Cool. So just to kind of blaze past things,
00:45:01.780 | I kind of started by talking through
00:45:04.980 | where things were pre-2017 when
00:45:07.740 | the transformer paper came out on both the RNN and the CNN side,
00:45:11.340 | and tied a lot of the innovation around the transformer
00:45:15.380 | to how modern convolutional neural nets,
00:45:20.500 | specifically residual nets were working,
00:45:22.300 | and the connections there were closer than the connections to RNNs.
00:45:26.300 | Kind of walk through how we got to the self-attention block with this fancy code,
00:45:31.860 | which is basically just saying like you're
00:45:33.580 | splitting your heads and you can kind of think of your heads in
00:45:36.300 | a self-attention block as the different kind of kernels or filters in a CNN layer.
00:45:41.460 | Then kind of closing with like this full self-attention block,
00:45:45.860 | where we're actually doing the RNN style attention,
00:45:48.500 | and then this question of this non-linearity that we're adding at the end.
00:45:52.500 | Because without this non-linearity and
00:45:54.780 | this sort of MLP that we're adding to the end of each transformer block,
00:45:58.180 | we're really just doing weighted averages of linear transforms of values.
00:46:02.820 | Okay. So, if we kind of take this as ground truth,
00:46:09.060 | starting point for what a transformer block looks like,
00:46:12.140 | very much inspired by the ideas of CNNs and RNNs with attention at the time.
00:46:17.580 | We have this residual connection here,
00:46:20.740 | which is kind of just adding X over and over
00:46:23.380 | again as we stack more and more layers together.
00:46:25.860 | There's a problem. Can anyone spot the problem in this implementation by itself?
00:46:31.420 | So, the problem is that activations blow up.
00:46:34.700 | We keep adding the same input over and over again as we go deeper.
00:46:38.980 | Eventually, specifically in the RNN attention layer,
00:46:43.700 | when we take this dot product between the queries and the keys,
00:46:46.460 | we're going to get overflow.
00:46:48.260 | So, we need to do something about that.
00:46:51.020 | All right. So, while the first part of
00:46:53.340 | kind of building the transformer layer is very,
00:46:54.940 | very much inspired by history,
00:46:56.680 | the second part is just trying to make sure it doesn't fail,
00:46:59.620 | and doesn't blow up, and doesn't crash when we try training it.
00:47:02.300 | So, what's one thing that we can do to kind of avoid
00:47:05.140 | this sort of blow up of our activations?
00:47:09.540 | So, layer normalization.
00:47:11.260 | So, layer normalization, maybe batch norm and layer norm were covered earlier on,
00:47:16.180 | is a very, very simple idea.
00:47:17.900 | Along each feature dimension,
00:47:19.540 | we're just going to normalize so that each feature has mean zero,
00:47:23.340 | standard deviation one, which means that every time we add a residual connection,
00:47:28.300 | we're going to normalize so that everything comes back to a decent space.
00:47:31.300 | We're still able to learn the kind of same level of expressivity we care about.
00:47:35.140 | We're just not necessarily going to
00:47:37.900 | keep blowing up or growing the magnitude of our activations.
00:47:41.660 | What that looks like is two calls to
00:47:45.860 | nn.LayerNorm with the dimensionality of our transformer,
00:47:49.660 | and then adding that into a res block.
00:47:51.940 | We're just going to normalize each X before we pass it
00:47:54.300 | into the attention and the MLP layers respectively.
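Concretely, that pre-norm residual block might look like the following sketch, where `attn` and `mlp` stand in for the self-attention and feed-forward modules from earlier in the lecture:

```python
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Normalize x before each sublayer so the residual stream can keep adding
    activations without blowing up the magnitudes the attention dot products see."""
    def __init__(self, d_model, attn, mlp):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)   # one LayerNorm per sublayer
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = attn
        self.mlp = mlp

    def forward(self, x):
        x = x + self.attn(self.ln1(x))     # normalize before attention
        x = x + self.mlp(self.ln2(x))      # normalize before the MLP
        return x
```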
00:47:57.420 | Now, there's a problem with this that isn't obvious,
00:48:00.860 | and actually wasn't obvious to the people building transformers at the time.
00:48:03.580 | It wasn't really explained kind of till three years later,
00:48:06.900 | which is that you have optimization issues when you do this.
00:48:12.340 | Specifically, if you just try to optimize
00:48:14.300 | the naive transformer with this layer norm in place,
00:48:17.420 | with kind of conventional ML wisdom,
00:48:19.500 | which is like learning rate decay or a constant learning rate,
00:48:23.900 | bad things happen.
00:48:25.740 | Specifically, I'm going to use the hugging face emojis,
00:48:27.980 | my stand in for a transformer.
00:48:30.180 | Stuff happens. The optimization crashes.
00:48:33.500 | It's either exploding gradients
00:48:35.300 | or it's vanishing gradients.
00:48:36.340 | If you asked someone in 2018 or 2019 or 2020,
00:48:38.540 | they would tell you one or the other is happening,
00:48:40.740 | but there are definitely no gradients that are like
00:48:42.380 | stable throughout the training process.
00:48:44.660 | So, you introduce this kind of weird thing.
00:48:47.580 | It kind of comes out of almost nowhere,
00:48:49.540 | which is like this transform,
00:48:50.980 | this like warm-up schedule that you
00:48:54.100 | see a lot of the time in
00:48:57.020 | like any code for training or
00:48:58.500 | even fine-tuning transformers these days.
00:49:00.540 | Now, this is actually just like fun because I have the time.
00:49:04.060 | I'm going to like go through it.
00:49:05.660 | Who came up with this?
00:49:07.980 | >> I'm thinking, like I think I remember in the [inaudible] paper,
00:49:21.100 | they had like a weird learning rate, but I don't remember.
00:49:23.380 | >> So, it is in the original transformer paper,
00:49:26.300 | like the main thing that they needed to get this stable.
00:49:28.980 | So, it's one of the authors that came up with it.
00:49:30.940 | But if you actually run a Git blame
00:49:33.580 | on the first ever transformer code base from Google,
00:49:35.780 | the Tensor2Tensor code base,
00:49:37.220 | like in the very first commit,
00:49:39.100 | in like the R parse like flags for the different optimizers you use,
00:49:44.340 | there's one option just called Noam, after Noam Shazeer.
00:49:47.540 | And in The Annotated Transformer,
00:49:49.460 | like the very first blog post,
00:49:50.820 | that's what the optimizer is called.
00:49:52.020 | It's called NoamOpt in Sasha Rush's code.
00:49:54.700 | And it's called the Noam optimizer for a really long time,
00:49:57.620 | until they just decided to call it just like,
00:49:59.580 | you know, linear warmup, cosine decay.
00:50:02.460 | And so, Noam Shazeer kind of came up with it.
00:50:05.420 | And if you were to kind of go back and think about like
00:50:08.340 | the sorts of problems and the papers he was working on at the time,
00:50:11.820 | he was actually doing a lot of stuff with
00:50:13.820 | the different types of like gradient descent optimizers,
00:50:19.660 | like RMSProp, Adafactor came out like a year after the transformer paper came out.
00:50:25.060 | And he was really like interested in like looking at this problem of like,
00:50:28.060 | "Huh, okay, weights seem to be where like if you just
00:50:30.900 | really inspect the gradients early on with like this layer norm thing,
00:50:34.340 | variance seems to be high and you kind of want to burn that in."
00:50:37.580 | And he's seeing this for LSTM,
00:50:38.900 | so he's kind of doing this already for his LSTM work.
00:50:42.460 | And then he just like, "Let's try this."
00:50:46.020 | It worked and no one really questioned it.
00:50:48.620 | It breaks conventional machine learning wisdom,
00:50:50.940 | like why am I warming up my learning rate before I'm bringing it down, right?
00:50:53.820 | Like if I'm optimizing some surface,
00:50:55.420 | I kind of like want to start kind of high,
00:50:57.660 | move in, and then like maybe anneal it as I get closer to my minimum.
00:51:01.860 | But no one is able to explain why,
00:51:04.900 | till three years later, a paper comes out that kind of like steps through
00:51:09.580 | the specifics of training a transformer model on some data,
00:51:13.340 | like synthetic data with the Adam optimizer,
00:51:16.260 | and actually ties it to the layer normalization layers that we just added.
00:51:20.060 | We fixed one problem, we added another.
00:51:21.860 | Right. So up top, we have kind of good gradients.
00:51:25.180 | Right. So on the left here is the gradient magnitude and here's the update magnitude.
00:51:29.020 | So the gradients that are computed and
00:51:30.820 | the updates that are actually applied to the weights.
00:51:32.700 | With warm up in blue,
00:51:34.220 | in red, we have the same thing but without warm up.
00:51:37.420 | And what ends up happening is that gradients go to zero somehow as you train.
00:51:41.820 | It's actually a weird graph because like as you're coming forward in time,
00:51:45.020 | it's like as you're training more and more.
00:51:47.180 | So this is kind of like starting out and then like this is kind of, yeah.
00:51:51.620 | And this is kind of like towards the,
00:51:54.420 | you know, wherever training becomes unstable.
00:51:57.380 | And then your updates also become super high variance.
00:52:00.860 | So they do some math and they kind of bound the update as
00:52:07.020 | proportional to
00:52:11.540 | the square root of the dimensionality of your transformer,
00:52:13.780 | divided by the norm of the input that's coming in.
00:52:16.100 | So if your input norm,
00:52:17.460 | like if the size of your activation is like sufficiently large,
00:52:21.940 | your layer norm gradient is going to be completely, completely screwed.
00:52:28.820 | So what they end up doing is like, okay,
00:52:31.780 | so warm up is necessary because it helps you get
00:52:35.780 | the Adam optimizer to kind of like move slowly enough at the beginning.
00:52:40.900 | So that we're kind of like saturating the gradients, we're like, okay.
00:52:44.940 | And then when we kind of go full throttle,
00:52:46.700 | like things are generally stable.
00:52:48.780 | The activations norms aren't changing all too much.
00:52:53.740 | They're changing in a predictable way and we can kind of start to handle that,
00:52:56.540 | and then conventional ML kicks in. But it's weird.
00:53:00.340 | And it's also weird that it took three years later,
00:53:02.340 | and some people still don't buy this explanation,
00:53:04.340 | but it's the best explanation I've got to why we need that warm up.
00:53:07.220 | So general wisdom,
00:53:09.700 | you're fine tuning or pre-training a transformer,
00:53:11.980 | warm up for at least five percent of your full training,
00:53:14.660 | and then start decaying. It just helps.
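(A rough sketch of that rule of thumb, assuming PyTorch's LambdaLR: linear warmup for about five percent of training, then cosine decay. The step counts, learning rate, and stand-in model are illustrative.)

```python
import math
import torch

def warmup_then_cosine(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: ramp up linearly, then decay with a cosine."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

total_steps = 100_000
warmup_steps = int(0.05 * total_steps)                # ~5% of training
model = torch.nn.Linear(10, 10)                       # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: warmup_then_cosine(step, warmup_steps, total_steps)
)
# In the training loop: optimizer.step() then scheduler.step() once per batch.
```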
00:53:17.580 | >> Can I ask this? I don't know this paper.
00:53:20.660 | So is there some data dependency or some assumption about what the data will be like?
00:53:26.220 | Because it seems like you said,
00:53:27.700 | hey look, after a while we can relax.
00:53:29.780 | These updates are small or reasonable,
00:53:33.140 | but the world could do a lot to you.
00:53:35.100 | And if you shifted genres or data types,
00:53:38.100 | it would go back into being very unstable.
00:53:40.700 | >> Yeah. So I think in this paper they're looking at what I'll call nice datasets.
00:53:45.300 | They're looking at the Wikitext 2s of the world that are somewhat predictable.
00:53:48.700 | It's all Wikipedia homogenized language.
00:53:51.100 | But even when you're training the modern,
00:53:53.060 | like the really big transformer these days,
00:53:54.980 | even after all of these tricks,
00:53:56.620 | you're still going to have bad batches.
00:53:59.060 | Just really, really unpredictable things that are low likelihood under your model,
00:54:03.660 | that are going to cause big updates in the middle of
00:54:05.100 | training that are going to completely crash your run.
00:54:07.380 | So this happened tons of times while we were training
00:54:09.020 | like the 355 million parameter models.
00:54:11.300 | This happened every time you're training
00:54:13.900 | any big model in like the one billion plus parameter range.
00:54:16.940 | So the EleutherAI folks,
00:54:18.500 | like this happens all of the time.
00:54:20.180 | The T5 models have this thing in like one of
00:54:24.020 | the notes in like the GitHub repository for like training a T5,
00:54:27.660 | which is like, if training fails,
00:54:29.940 | rewind to the latest checkpoint,
00:54:31.500 | re-randomize your data order and then try again,
00:54:33.820 | and it won't crash.
00:54:35.140 | That's kind of how modern ML or
00:54:38.420 | most modern language models are trained right now.
00:54:42.060 | We don't know how to avoid it yet.
00:54:43.740 | We think it's tied to the data, we can't isolate it.
00:54:45.940 | So why not just re-roll and try again?
00:54:49.340 | Eventually, you'll just keep making progress.
00:54:51.740 | Cool. So question.
00:54:55.980 | >> Back to the graphs,
00:54:57.220 | what do the different colors represent in-
00:55:00.740 | >> Yeah. So the question was,
00:55:02.220 | what the different colors represent.
00:55:03.660 | So up top is blue with the traditional transformer learning rate,
00:55:08.060 | so warm up and then go down.
00:55:10.780 | Red is the no warm up,
00:55:14.420 | let's just start the learning rate high and then taper.
00:55:16.820 | So red is bad, blue is good.
00:55:20.060 | >> [inaudible]
00:55:26.660 | >> Yeah. So the way you can interpret this graph,
00:55:29.020 | and the paper's linked in at the bottom of the slide,
00:55:31.980 | but you can think of the furthest back magnitude is like this,
00:55:36.820 | basically plotting the mean standard deviation of the updates across layers.
00:55:41.540 | The furthest back is like batch zero.
00:55:44.700 | As you get further in,
00:55:47.300 | you get to batch 100 or batch 400 or whatever.
00:55:51.260 | Yeah. Question.
00:55:54.580 | >> I wonder to what extent the warm up and the rate.
00:56:00.060 | >> I think I'm relating to the choice of optimizer.
00:56:05.540 | Because I've run into some problems where I found that using
00:56:11.060 | Adamax, the infinity-norm variant, will work because I've got this level of dropout or whatever.
00:56:15.740 | Is there any guidance of all these big hyperparameters that go into this tune,
00:56:21.580 | where if I pull one lever,
00:56:23.580 | I should be pushing one down or choose this optimizer,
00:56:26.380 | I should do another because it feels like,
00:56:28.260 | I mean, taking three years,
00:56:31.580 | it feels a little bit like the Wild West, which is what it is.
00:56:34.820 | >> Yeah. So if I were to paraphrase your question,
00:56:38.540 | it's like, if you decide to change anything about the current recipe,
00:56:42.380 | like change your optimizer, change dropout, change your learning rate,
00:56:45.220 | are there rules of thumb for what else you need to change to get things to work?
00:56:50.060 | No. It is.
00:56:54.820 | So part of why I led with how we got to here,
00:57:00.500 | starting from the historical context,
00:57:02.500 | was to unpack a lot of this folk knowledge.
00:57:05.620 | Because it's still at the point where optimizing these models,
00:57:08.980 | especially as we go bigger,
00:57:10.660 | is still concentrated in the minds and experience of a very small number of people.
00:57:16.820 | Because who's trained a seven billion parameter language model,
00:57:19.460 | or who's trained a 100 billion parameter language model?
00:57:21.900 | Where does that skill set come from?
00:57:23.340 | When you're talking about a training run that cost millions of dollars to develop,
00:57:28.180 | plus however much the compute costs,
00:57:31.340 | how many things are you really going to be trying at the end?
00:57:34.300 | What things can you extrapolate from?
00:57:36.740 | So folks at OpenAI have definitely done
00:57:38.740 | like scaling laws research where they're trying
00:57:40.500 | these different things in some bounded search space.
00:57:43.420 | But if you were to invent like a brand new optimizer,
00:57:46.100 | it kind of looks at like maybe second,
00:57:47.780 | third, fourth order moments of your gradients,
00:57:51.300 | maybe do something fancy relative to how things are changing over time.
00:57:55.860 | And you were to just try and apply it to the biggest language model you could train,
00:58:00.740 | I have no idea what I would tell you in terms of like what things you should change,
00:58:05.100 | beyond like some obvious things,
00:58:06.220 | like maybe don't set the learning rate to be ridiculously high.
00:58:09.140 | Starting now.
00:58:10.340 | >> If you have a batch of data that's like,
00:58:16.380 | let's just say during training you come across a bad batch of data that happens to
00:58:20.300 | cause like the destabilization with the gradients.
00:58:23.700 | And then you rewind back to your checkpoint and you take that same exact batch of data,
00:58:28.020 | but instead of running it through training when gradients are enabled,
00:58:32.580 | if you just run it through inference,
00:58:34.460 | will that bad batch of data have caused like
00:58:37.500 | anomalous behavior from the model during inference,
00:58:41.060 | or is it strictly just a back propagation issue?
00:58:43.580 | >> So when we debug or when we debug this,
00:58:47.140 | so we a couple of years ago trained like some,
00:58:50.980 | by today's standards, really,
00:58:52.500 | really small language models,
00:58:53.660 | but like 124 million to 355 million scale.
00:58:56.940 | We were noticing this problem.
00:58:59.140 | The way we debugged it was like looking at the gradients,
00:59:02.140 | which didn't tell us much, but then we just looked at activation norms per layer,
00:59:06.140 | and that's how we actually debug this, right?
00:59:08.780 | So looking at the forward pass,
00:59:10.300 | looking at the magnitudes of each layer,
00:59:12.420 | like where we thought we could possibly be overflowing or underflowing,
00:59:16.380 | that's exactly how we debugged it.
00:59:18.980 | But we didn't debug it at the batch level.
00:59:21.100 | We debugged it as a function of time, right?
00:59:23.540 | Because a single batch isn't going to perturb everything.
00:59:26.260 | A series of batches,
00:59:28.060 | like maybe two, three, who knows how many,
00:59:29.980 | eventually you're going to fall under some trajectory where things get bad,
00:59:33.300 | your activations blow up.
00:59:34.900 | So we would be able to deterministically be running to that checkpoint,
00:59:38.060 | then deterministically step through training and log every activation,
00:59:42.540 | which is expensive, but that's how we were able to get to the bottom of the problem.
00:59:46.060 | But I don't think we have tools for actually figuring out which batch of data was,
00:59:51.100 | or which sequence of batches of data were the actual triggers for that behavior.
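(A sketch of that kind of per-layer logging, assuming PyTorch forward hooks; the helper and the stand-in model here are mine, not the actual debugging code.)

```python
import torch
import torch.nn as nn

def attach_norm_hooks(model: nn.Module):
    """Record the norm of every module's output on each forward pass.
    Expensive, but useful when hunting for overflow/underflow."""
    norms = {}

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                norms[name] = output.detach().float().norm().item()
        return hook

    handles = [m.register_forward_hook(make_hook(name))
               for name, m in model.named_modules() if name]
    return norms, handles

# Toy usage with a stand-in model:
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
norms, handles = attach_norm_hooks(model)
model(torch.randn(8, 64))
print(norms)          # per-layer activation norms for this forward pass
for h in handles:
    h.remove()
```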
00:59:55.860 | >> I guess I was just curious specifically about if the same data that causes
01:00:00.020 | destabilization in training can cause anomalous behavior,
01:00:05.060 | or just like in a normal forward pass?
01:00:07.300 | >> It probably would.
01:00:09.380 | There's some more recent work about how to quantize these models,
01:00:12.860 | like how to get a transformer that's trained with like 16-bit precision to
01:00:17.060 | like train or to run with like eight bit precision by like intelligently like
01:00:21.740 | bucketing floats, from Tim Dettmers, who's a PhD student up at UW.
01:00:27.140 | He has this theory on something called outlier features that show up in
01:00:30.860 | these really big models that kind of try and get at this,
01:00:33.420 | but more of an art than a science right now.
01:00:36.980 | Yeah. Okay. So are we done now that we fix this like layer norm stuff,
01:00:44.260 | the learning rate stuff,
01:00:45.660 | all of the stuff to get the transformer to work?
01:00:47.660 | Kind of. Right. So like over the last few years,
01:00:50.040 | like training has been stable,
01:00:51.420 | people want to like milk the most of the transformer,
01:00:54.500 | you know, especially as they scale up.
01:00:57.540 | So they do a couple of things.
01:00:59.900 | So one, when you're training and you're projecting to queries and
01:01:04.460 | keys at sufficient scale,
01:01:07.220 | the bias term in each linear layer, the b in Wx plus b,
01:01:10.580 | you can get rid of because like it's not really doing
01:01:12.540 | anything and it saves a little bit of compute.
01:01:14.460 | So let's get rid of them. Like that's like the first thing to throw out.
01:01:17.700 | There are different activations that have been invented and different types
01:01:21.020 | of like cool ways to just better fit your data.
01:01:25.820 | So there's this SwiGLU,
01:01:27.220 | a swish gated linear unit, which actually
01:01:29.820 | defines like a separate weight matrix as part of the activation.
01:01:33.940 | And then the swish is like a sigmoid-style activation
01:01:39.740 | that applies to one part of the projection and not the other.
01:01:41.500 | I've code for all of this. It works better.
01:01:44.460 | This is actually the activation of choice now in most transformer implementations.
01:01:49.420 | So LLaMA was trained with this,
01:01:51.020 | PaLM was trained with this, works really well.
01:01:55.060 | One thing that folks noticed is that moving the layer norm to happen,
01:02:00.860 | you know, before you actually feed it through the attention or the MLP layers instead of
01:02:07.340 | after is a more stabilizing force is actually kind of important.
01:02:11.780 | Also a layer norm has trainable weights.
01:02:14.660 | So some papers decide to be like,
01:02:16.580 | you don't actually need these trainable parameters for mean and variance.
01:02:20.980 | You can actually just like divide by the root mean square,
01:02:23.500 | or the RMS, of like your entire activation vector.
01:02:27.100 | All things to just get rid of
01:02:28.940 | irrelevant flops because we're training massive models and we're
01:02:31.700 | trying to do the bigger model on the compute that we have.
01:02:35.900 | Oh, yeah, here's the code for like a SwiGLU activation and an RMS norm.
01:02:40.860 | So SwiGLU is like,
01:02:42.220 | so the SiLU is like basically a sigmoid-weighted unit,
01:02:44.380 | and a projection layer is basically saying like let's take
01:02:46.580 | this input feature and project it into like two separate chunks.
01:02:50.260 | One chunk becomes like a gating value kind of like
01:02:53.260 | in a gated recurrent unit in like the RNN literature.
01:02:57.300 | One becomes the actual value,
01:02:59.540 | you apply the sigmoid to the gate and then multiply it element-wise with
01:03:03.100 | the value and you get your new thing, works really well.
01:03:06.620 | An RMS norm is like literally just dividing by the norm
01:03:10.180 | of the vector instead of like trying to learn anything fancy.
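(My own reconstruction of what that slide's code plausibly looks like, assuming PyTorch: a SwiGLU feed-forward layer and an RMSNorm. Details like the bias-free linears and the learned gain in RMSNorm follow common implementations, not necessarily the lecture's exact code.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Project the input into two chunks; SiLU-gate one and multiply element-wise."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.proj = nn.Linear(d_model, 2 * d_hidden, bias=False)  # biases dropped at scale
        self.out = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.proj(x).chunk(2, dim=-1)
        return self.out(F.silu(gate) * value)

class RMSNorm(nn.Module):
    """Divide by the root mean square of the feature vector; no mean/variance stats.
    Most implementations keep a single learned gain, included here."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gain * x / rms
```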
01:03:13.940 | Cool. This is what the modern transformer looks like.
01:03:17.820 | So that's it for the evolution of the transformer.
01:03:21.940 | As far as I know,
01:03:23.420 | nothing in the last two weeks has like changed drastically from this.
01:03:26.860 | In the last two weeks. To what extent,
01:03:29.900 | let's say we are doing a fine-tune instead of like the full train,
01:03:34.500 | or we're doing like lower on top of it.
01:03:36.660 | How would we still want to follow these kinds of guidelines or these specific to
01:03:41.220 | just doing all of the data doing a full pre-train?
01:03:45.140 | Yeah. So question is,
01:03:46.860 | what of this do we really need if we're fine-tuning,
01:03:50.420 | or if we're kind of doing like parameter efficient fine-tuning,
01:03:52.460 | like is this only necessary for pre-training?
01:03:54.380 | So I've started using the Swish Glue instead of like any other activation like everywhere,
01:04:00.540 | even for like a two-layer MLP,
01:04:02.340 | tends to work better.
01:04:04.540 | So take that with a grain of salt.
01:04:07.340 | Everything else you can probably not care about.
01:04:10.540 | The RMS norm, the pre-norm is probably just like a general rule of thumb if you're adding
01:04:16.180 | any transformer layers just because it is like demonstrably more stable,
01:04:19.980 | but other than that, no.
01:04:22.900 | Other questions here before we move to how to train on lots and lots of compute.
01:04:30.300 | Cool. So let's talk about training at scale.
01:04:35.500 | So I'll start with a story, my story.
01:04:40.940 | Okay. So I am not old,
01:04:44.740 | but I have seen a few things as far as language models have gone.
01:04:50.340 | So like 2018 is when I think did my first deep learning tutorial.
01:04:55.940 | I trained a MNIST like the typical like two-layer,
01:04:59.740 | four-layer MLP for classification.
01:05:03.060 | There's actually a line there.
01:05:04.940 | It's the 100,000 parameter line. That's 2018.
01:05:09.660 | As I kind of start my PhD in 2019,
01:05:12.700 | I'm doing more NLP stuff.
01:05:14.220 | I'm looking at like word vectors,
01:05:15.700 | RNNs, some more sophisticated things.
01:05:18.500 | I'm getting up to a million parameters.
01:05:21.460 | 2020, I kind of branch out from like the small NLP stuff I'm doing to like
01:05:26.100 | more intensive NLP so looking at tasks like summarization,
01:05:30.420 | training models with like 10 million parameters.
01:05:33.580 | Then by 2021, like the biggest models I was training,
01:05:39.580 | was like when I switched into multimodality robotics,
01:05:41.660 | looking at visual question answering, 18 million parameters.
01:05:45.460 | At the time, the standard pipeline for me,
01:05:47.660 | and I think this was the standard pipeline for a lot of
01:05:49.340 | grad students that I talked to then,
01:05:52.020 | is like I'd be able to train most of my things on
01:05:55.140 | one GPU or even my laptop CPU for like a maximum of a few hours.
01:06:01.220 | I got it is what a training run would take,
01:06:03.620 | at least for like most of the things I was doing on a day-to-day.
01:06:07.020 | But in 2021, Percy's like,
01:06:12.620 | "Hey, this GPT-3 thing seems cool.
01:06:15.260 | Let's at least figure out if we can get an academic lab to
01:06:18.140 | like try and train a GPT-2 like the earlier generation."
01:06:21.420 | So clocking in at like 124 million parameters,
01:06:24.700 | which is notably an order of
01:06:26.700 | magnitude bigger than anything I trained at the time.
01:06:28.860 | So why I decided to do this is still beyond me,
01:06:31.260 | but I learned a lot of useful things.
01:06:33.300 | One of the useful things that I learned is that training
01:06:37.860 | a 124 million parameter model on
01:06:40.940 | a decent GPU that we had access to at the time,
01:06:44.300 | with a batch size greater than four would go out of memory,
01:06:47.420 | which is bad because a batch size of four is small,
01:06:50.340 | and we wanted to ideally train with like a batch size of 512.
01:06:53.660 | So there was a simple trick,
01:06:56.460 | and it's called gradient accumulation,
01:06:57.980 | which is like I'm going to run batch sizes of
01:07:00.140 | four however many times it takes to get into 512,
01:07:04.060 | and then do an update after
01:07:05.740 | processing all of those batches sequentially.
01:07:07.860 | So I'm just going to keep accumulating the gradients,
01:07:10.140 | and PyTorch makes that really easy.
01:07:11.580 | It's just a for loop and an if statement.
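(A sketch of that for loop and if statement, assuming PyTorch. The micro-batch and target batch sizes mirror the numbers above; the model and data are stand-ins.)

```python
import torch
import torch.nn.functional as F

micro_batch = 4                             # what fits in GPU memory
target_batch = 512
accum_steps = target_batch // micro_batch   # 128 micro-batches per optimizer update

model = torch.nn.Linear(128, 2)             # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loader = [(torch.randn(micro_batch, 128), torch.randint(0, 2, (micro_batch,)))
          for _ in range(accum_steps)]      # stand-in data

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = F.cross_entropy(model(x), y)
    (loss / accum_steps).backward()         # scale so the sum matches a big-batch average
    if (step + 1) % accum_steps == 0:       # the "if statement"
        optimizer.step()
        optimizer.zero_grad()
```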
01:07:14.060 | But if you do the math,
01:07:16.860 | it's like 100 days to train on that single GPU for 400,000 steps.
01:07:20.900 | So how do we go from this clock of 100 days to something reasonable?
01:07:26.340 | That's what we're going to talk about.
01:07:27.940 | So with the scaling toolbox,
01:07:32.180 | at least as far as we were concerned,
01:07:37.300 | ended up looking like three different parts across
01:07:42.780 | 16 GPUs because Percy and Chris Ré,
01:07:45.460 | and I think Chris and Dan and Chris Manning,
01:07:48.900 | and the NLP group decided to invest upfront in
01:07:51.220 | like really powerful GPU machines,
01:07:53.660 | so we could actually train on like 16 GPUs at once.
01:07:57.100 | For reference, 16 GPUs on AWS,
01:08:02.260 | just rent on an hourly basis is 56 bucks an hour now.
01:08:06.580 | Fifty six bucks an hour if you want to just like sit on them.
01:08:11.460 | But if you're willing to like let anyone who has
01:08:15.180 | the money and like wants to sit on them like preempt you,
01:08:18.140 | you can get them for 16 bucks an hour.
01:08:19.980 | So like across four days like that's not the worst.
01:08:23.180 | It's like not great, but totally doable.
01:08:27.340 | So the scaling toolbox we ended up looking at was data parallelism.
01:08:32.420 | You can think about this as like literally just divide and conquer.
01:08:35.220 | How do I just parallelize work across all of these GPUs instead of one?
01:08:40.860 | Mixed precision training, and we're going to talk a little bit about what that means.
01:08:46.820 | Then this interesting idea called zero redundancy,
01:08:49.620 | which is about minimizing the memory footprint of training.
01:08:52.820 | Then later on as you want to scale up to hundreds of billions of
01:08:58.140 | parameters on 256, 512, 1024, 2048 GPUs.
01:09:05.340 | We'll talk like there are things that come in handy like model parallelism.
01:09:09.700 | There are things to consider like hardware and software limitations.
01:09:14.700 | But some of you might be here looking at me which is like, okay,
01:09:18.740 | do I need any of this stuff if I'm not training really big models?
01:09:21.060 | Like if I'm just fine tuning stuff.
01:09:23.620 | A lot of these tips and tricks like you may not have access to 100 GPUs or
01:09:31.020 | even eight, but you might have access to two or four, comes in handy.
01:09:35.460 | A lot of the ideas here are still ideas that I'm using when I'm training stuff on
01:09:40.100 | my laptop or when I'm trying to run inference with the latest big model
01:09:43.020 | that is publicly released, so it's useful.
01:09:45.140 | But please ask questions if things become too hazy or too not useful.
01:09:52.180 | >> [INAUDIBLE]
01:09:53.700 | >> Mm-hm.
01:09:55.060 | >> For people relying on Colab,
01:09:58.180 | data parallelism might actually not help.
01:10:01.780 | >> So Colab, yeah, Colab, you're still limited to a single GPU.
01:10:04.980 | >> And I'm guessing zero redundancy might help.
01:10:08.940 | >> So mixed precision would help, kind of, definitely for running inference.
01:10:15.660 | And zero redundancy would also help running inference.
01:10:18.140 | >> What's the zero redundancy?
01:10:20.380 | >> So zero redundancy has an add-on that they wrote up in a paper later called
01:10:24.940 | ZeRO-Infinity, which is like, what if I didn't put all my weights on the GPU at
01:10:28.500 | once, what if I put some of them in CPU RAM or even in NVMe SSD storage?
01:10:35.100 | So actually turning your laptop into a more powerful workhorse than a Colab GPU.
01:10:39.460 | Cool, so this is a toy example kind of going through data parallelism.
01:10:47.340 | We're running low on time-ish, so this is MNIST with an MLP.
01:10:54.700 | We're trying to do classification.
01:10:56.340 | It's kind of the typical PyTorch workflow.
01:10:59.180 | I'm going to define an N dot module.
01:11:01.580 | I'm going to define a batch size, a data loader that's going to load from
01:11:04.420 | the TorchVision data set.
01:11:06.020 | And then I'm just going to run lots and lots of gradient steps.
01:11:10.420 | The idea here is, how do we parallelize this across multiple workers,
01:11:17.620 | multiple GPUs?
01:11:19.140 | Well, that batch size you see there is totally divisible, especially given that
01:11:24.420 | what we're doing at the end when we kind of compute the loss is just take an average.
01:11:28.220 | An average of averages is still the average.
01:11:31.860 | That's the idea we're going to work with.
01:11:33.420 | The mean of means is still the global mean.
01:11:36.460 | So just like in CPU land where you can kind of think about SIMD instructions,
01:11:41.020 | like single instruction, multiple data.
01:11:42.860 | Right, so this is kind of how most graphics and
01:11:45.500 | media operations work on your laptops.
01:11:48.260 | We're going to now think about the SPMD paradigm.
01:11:51.220 | I'm going to write one program, and it's just going to automatically scale to
01:11:54.980 | being split across our running machines,
01:11:56.900 | because we're going to split the data across multiple machines.
01:11:59.460 | It seems hard, but as of PyTorch 1.4,
01:12:03.140 | a lot of the hard parts are taken care of for you.
01:12:06.220 | These are the only lines you need to change in the implementation.
01:12:09.060 | Two of them are import statements.
01:12:11.220 | So the first thing we're going to do is we're going to just create something
01:12:14.780 | called a distributed sampler, which is going to automatically partition our data
01:12:17.980 | across the number of workers we define up front.
01:12:20.580 | Right, we're defining a world size of eight, so
01:12:23.820 | that means we're training on eight GPUs.
01:12:25.500 | So this is going to partition our data into eight different subsets that each
01:12:28.380 | worker gets to go through.
01:12:29.540 | We're going to wrap our nn.module with this nice little wrapper,
01:12:34.860 | this distributed data parallel wrapper,
01:12:36.980 | which is going to sync the gradients for us behind the scenes.
01:12:40.340 | And then we're going to run this with a special command called torchrun,
01:12:45.940 | which is just going to inject a bunch of environment variables so
01:12:48.660 | that we can get some statistics about our local rank,
01:12:53.020 | who's the guy who should be printing stuff to the screen,
01:12:55.300 | who's the guy who should be logging stuff, where each worker lives.
01:13:00.060 | And that's about it, and you can do all of this, and
01:13:05.260 | you just parallelize naively across 16 GPUs.
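(Putting those pieces together, a minimal sketch of the pattern, assuming PyTorch's DistributedDataParallel and a torchrun-style launcher. The dataset, model, and sizes are stand-ins, not the course's training script.)

```python
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Launched with something like: torchrun --nproc_per_node=8 train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

dataset = TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,)))
sampler = DistributedSampler(dataset)           # partitions data across workers
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

model = torch.nn.Linear(784, 10).cuda(local_rank)
model = DDP(model, device_ids=[local_rank])     # syncs gradients behind the scenes
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(2):
    sampler.set_epoch(epoch)                    # reshuffle the partition each epoch
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        loss = F.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()                         # gradients get all-reduced here
        optimizer.step()
    if dist.get_rank() == 0:
        print(f"epoch {epoch} done")
```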
01:13:07.860 | You get not quite a 16x speedup, because there is some overhead from communication.
01:13:13.460 | It's like seven days, that's cool.
01:13:16.140 | It was not good enough, because we were trying to train lots of models
01:13:19.620 | reproducibly, five seeds for like ten different model types, so like 50 models.
01:13:26.060 | So we needed to go a little faster than this.
01:13:28.540 | So let's talk about memory footprints.
01:13:30.900 | When I am training any model with an Adam optimizer,
01:13:35.540 | how much memory does just storing that model and the optimizer weights take up?
01:13:43.980 | So in 32-bit precision, our model's going to have parameters,
01:13:50.260 | where each parameter is stored with 32 bits, that's a float.
01:13:54.100 | Gradients, 32 bits.
01:13:56.380 | Now your optimizer is also going to do this weird thing where it's going to have
01:13:58.380 | a copy, its own separate copy of the parameters,
01:14:00.540 | like kind of duplicating a little bit of work there.
01:14:02.940 | That's also 32 bits.
01:14:04.660 | And then Adam tracks momentum and variance, like the first- and
01:14:08.540 | second-order moments of the gradients.
01:14:10.940 | So that's another 64 bits right there.
01:14:14.580 | So the lower bound on static memory, just like storing this stuff on a GPU,
01:14:19.300 | is 20 bytes times the number of parameters that you have.
01:14:21.620 | This doesn't include activations at all for these larger transform models.
01:14:26.980 | If I want to keep around every buffer, like every intermediate matrix
01:14:31.380 | as I pass it through the network, that takes up way more space.
01:14:34.420 | But this at least gives us something that we can reason about.
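(A quick back-of-the-envelope check on that 20-bytes-per-parameter figure; my own arithmetic, not from the slides.)

```python
# Full-precision Adam training, static memory only (no activations):
# 4 bytes each for parameters, gradients, the optimizer's parameter copy,
# momentum, and variance -> 20 bytes per parameter.
bytes_per_param = 4 + 4 + 4 + 4 + 4

for n_params in (124e6, 1e9, 175e9):
    gib = n_params * bytes_per_param / 2**30
    print(f"{n_params / 1e9:7.3f}B params -> ~{gib:,.0f} GiB at rest")

# 1e9 params * 20 bytes is roughly 18.6 GiB, matching "about 18 gigs resting";
# 175e9 params * 20 bytes is a bit over 3 TiB before any activations.
```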
01:14:38.700 | The training implications of this is that if I want to fit a model with
01:14:42.260 | a billion parameters, that's going to take about 18 gigs resting,
01:14:46.140 | 31 gigs of GPU RAM with a batch size of one.
01:14:49.900 | Which is problematic, because most GPUs then cap out at 24 gigs.
01:14:56.100 | The really expensive ones now have like 40 or 80, but this is still bad.
01:14:59.740 | 175 billion parameters would take three terabytes of RAM,
01:15:04.980 | not storage, just RAM, without activations.
01:15:08.180 | With activations, it's probably looking like ten terabytes.
01:15:10.860 | Good luck.
01:15:11.380 | >> [INAUDIBLE]
01:15:13.740 | >> The numbers in bold are batch size one.
01:15:15.540 | Numbers not in bold are just putting it on the thing.
01:15:19.180 | And things you should know about floats,
01:15:25.300 | it was a standard defined in this IEEE document.
01:15:29.020 | You have a one-bit sign, an eight-bit exponent, and a 23-bit mantissa,
01:15:34.620 | like all the stuff that happens after the exponent.
01:15:37.460 | Wide range, up to 1E38.
01:15:39.900 | And the question is, do you need that range?
01:15:41.940 | Answer is, kind of, but not really.
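(A quick way to see those ranges yourself, assuming PyTorch; torch.finfo reports the limits of each float format. My own snippet.)

```python
import torch

# Ranges and precision of the common float formats.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max ~{info.max:.3e}   eps (step near 1.0) {info.eps:.3e}")

# float32 tops out around 3.4e38; float16 around 6.5e4, which is why overflow
# becomes a real concern once you drop to half precision.
```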
01:15:45.060 | So the mixed precision memory footprint.
01:15:47.220 | If I'm training a model in mixed precision, what that means is that I'm just
01:15:50.700 | going to run everything in a forward pass and
01:15:53.900 | part of the backwards pass in 16 bit precision instead of 32 bit precision.
01:15:59.940 | Notably, what that means is now I'm storing my parameters in 16 bits,
01:16:03.380 | my gradient in 16 bits.
01:16:05.660 | All of those intermediate activations that take up lots and
01:16:08.820 | lots and lots of memory, especially as you go bigger, are halved, which is great.
01:16:12.540 | But the weird part about mixed precision is not everything is mixed precision.
01:16:17.900 | So your optimizer to stably update your model still needs the 32 bit
01:16:21.700 | parameter copies, your 32 bit momentum, 32 bit variance.
01:16:25.660 | But you've dropped four bytes, and those four bytes are kind of useful.
01:16:31.260 | Yet training with mixed precision, at least a couple years ago, and
01:16:34.940 | it's still mostly true now, is still way faster than training with full precision.
01:16:40.660 | And the reason for that is most NVIDIA GPUs,
01:16:45.580 | starting with the Volta cards, started shipping with these things called tensor cores.
01:16:51.860 | Tensor cores are basically the individual logical units on a GPU that are responsible
01:16:56.940 | for matrix multiplies.
01:16:58.660 | Your GPU is really good at accelerating neural network training because it's
01:17:01.860 | really good at doing matrix multiplies.
01:17:04.020 | These things are optimized for 4x4 to 16x16 size shards.
01:17:11.660 | If you're training in 16 bit precision,
01:17:13.460 | you can actually end up using way more tensor cores than if you were using
01:17:17.580 | 32 bit precision.
01:17:18.460 | And so you're able to get a ton of speed ups just because you're able to tap
01:17:23.420 | into the underlying hardware of your system more frequently.
01:17:26.460 | As of the Ampere style cards, like the A100s or
01:17:31.380 | the more recent 3090s, 3090TIs, 4090s.
01:17:36.140 | Those start shipping with these cores that are able to do float 32 precision, but
01:17:41.420 | are still even faster for 16 bit precision.
01:17:43.940 | So when you can, train with 16 bit precision.
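(A sketch of one common way to do mixed precision in PyTorch, using torch.cuda.amp; the model and loss are stand-ins, and this assumes a CUDA GPU is available.)

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()             # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                    # keeps fp16 gradients from underflowing

for _ in range(10):
    x = torch.randn(32, 1024, device="cuda")
    optimizer.zero_grad()
    # Forward pass runs in fp16 where safe; parameters and optimizer state stay fp32.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)                              # unscale, then fp32 update
    scaler.update()
```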
01:17:46.740 | All right, this shaves a day off of our small scale training.
01:17:54.020 | This shaves off way more, especially as you go bigger.
01:17:58.700 | And now the final bit is how do we eliminate the redundancies?
01:18:02.460 | Right, so in standard data parallelism.
01:18:04.380 | Yeah.
01:18:06.020 | >> I have a question.
01:18:07.020 | So why do you need the 32 bit precision for the optimizer,
01:18:10.220 | but not for the model?
01:18:11.140 | >> So when you are estimating the gradients, precision matters more.
01:18:16.500 | You want that full range.
01:18:18.300 | Specifically, you want those 23 mantissa bits that kind of correspond to
01:18:23.700 | everything after the exponent to be meaningful.
01:18:27.460 | Because while the full range of float 32 is really, really big,
01:18:32.380 | it actually can't be super precise in the 0, 1 range, for example.
01:18:38.660 | So you kind of want as much there as possible to kind of ensure precision.
01:18:44.140 | Okay, so zero redundancy, standard data parallelism.
01:18:48.940 | You're basically storing everything on each GPU.
01:18:51.780 | Key idea is I don't need to store everything on every GPU.
01:18:56.380 | I just need to store some things on every GPU.
01:18:59.100 | So the model gets to stay on each GPU cuz the model has to do all the work.
01:19:03.140 | But the gradients, I can just split across the number of devices that I have.
01:19:06.140 | So half my gradients, if I have two GPUs, go on one device,
01:19:09.580 | half of them go on the other device.
01:19:11.700 | Same with the optimizer states.
01:19:12.900 | Half of them go on one device, half of them go on the other device.
01:19:15.340 | With this model of zero redundancy,
01:19:17.460 | you're actually not adding any extra communication cost.
01:19:20.540 | You just get free memory because you're just intelligently partitioning
01:19:24.620 | things across devices.
01:19:26.780 | Notice that this scales, you use less and
01:19:31.620 | less memory as you add more and more machines, right?
01:19:34.780 | So this is kind of like the biggest trick to start training 1 billion to
01:19:37.940 | 10 billion parameter plus models.
01:19:40.780 | And now when you add this, you're at three days.
01:19:43.260 | >> So would you have to tell it what,
01:19:48.660 | cuz this would require things being slightly out of sync to optimize.
01:19:53.100 | Would you have to- >> So this actually doesn't require
01:19:55.060 | anything out of sync, because in a backwards pass with distributed data
01:19:58.700 | parallel, the individual updates already have to be synced across processes.
01:20:05.700 | If you take the average of the loss across all processes,
01:20:08.860 | that means that the gradients you apply have to be transferred as well.
01:20:12.420 | So this just basically does that work for you.
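(One way to get the optimizer-state sharding described above, assuming PyTorch's built-in ZeroRedundancyOptimizer inside an already-initialized DDP script. This covers the optimizer-state piece of ZeRO; libraries like DeepSpeed also shard gradients and, at the largest scales, parameters.)

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumes dist.init_process_group(...) has run and the model is DDP-wrapped,
# as in the earlier sketch. Instead of every rank holding full Adam state,
# the parameter copies / momentum / variance get partitioned across ranks.
model = torch.nn.Linear(1024, 1024)        # stand-in for your DDP-wrapped model
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1e-4,
)
# The training loop is unchanged: loss.backward(); optimizer.step()
```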
01:20:15.860 | We're gonna wrap up.
01:20:20.100 | You hit a communication wall, matrix multiplies
01:20:22.300 | stop fitting on a device, so you start sharding them.
01:20:24.860 | And then you have to start scheduling things wisely, and yeah, great.
01:20:27.980 | Fine tuning inference, there is a great library that you should use
01:20:32.140 | called PEFT from Hugging Face, it's great, and that's it.