Stanford XCS224U: NLU | Fantastic Language Models and How to Build Them, Part 2 | Spring 2023
00:00:09.400 |
The plan is to finish up our review of core information retrieval stuff. 00:00:15.280 |
The focus will be on neural information retrieval, 00:00:25.240 |
and Sid is going to help us talk again about how to train these big models. 00:00:31.200 |
So let's dive in. We'll start by using our big handout here, 00:00:41.520 |
That's right. I had a couple more metrics that I wanted to show you. 00:00:45.680 |
So last time we talked about how assessment in the space of IR should be multidimensional. 00:00:56.080 |
We are going to circle back and talk about these other dimensions, 00:00:59.200 |
which I regard as absolutely crucial in this space. 00:01:03.040 |
But with that said, we did dive into different metrics. 00:01:16.720 |
That is, is there a relevant document at or above k? 00:01:26.560 |
So the D1 ranking has a success at 2 of 1 because there is a star at 2 or above. 00:01:30.240 |
D2, that ranking also has a success of 1 because there is a star at 2 or above, 00:01:39.760 |
And you can see already that it's coarse-grained because D1 and D2 are not differentiated, 00:01:50.140 |
Reciprocal rank is a little bit better in the sense that it's more or less just 00:01:55.940 |
registering whether there's a star at or above k, 00:01:58.900 |
except now we are sensitive to where the topmost star is ranked. 00:02:03.020 |
So for example, D1 here has an RR at 2 of 1 because there is a star in first place. 00:02:09.820 |
Whereas D2 has 1 over 2 because the first star is in second place. 00:02:19.940 |
So pretty coarse-grained but very intuitive, and sometimes success and RR are 00:02:25.500 |
good metrics in the sense that you kind of just want to know, for 00:02:32.380 |
a given query, whether the user will quickly find something relevant. And especially if you only have one relevant document per query, 00:02:38.500 |
success at k may be all you need, and then RR will just be a little bit more nuanced. 00:02:44.180 |
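As a rough illustration (this code is mine, not from the lecture), here is a minimal sketch of success at k and RR at k in Python, representing a ranking as a list of 0/1 relevance labels ordered by rank; the example rankings are hypothetical stand-ins for the handout's D1 and D2.

```python
def success_at_k(ranking, k):
    # 1 if any relevant document ("star") appears at rank k or above, else 0.
    return 1.0 if any(ranking[:k]) else 0.0

def rr_at_k(ranking, k):
    # Reciprocal of the rank of the first star at or above k; 0 if there is none.
    for rank, rel in enumerate(ranking[:k], start=1):
        if rel:
            return 1.0 / rank
    return 0.0

d1 = [1, 1, 0, 0, 0, 1]   # hypothetical ranking with stars at positions 1, 2, and 6
d2 = [0, 1, 0, 0, 1, 1]   # hypothetical ranking whose first star is at position 2
print(success_at_k(d1, 2), rr_at_k(d1, 2))   # 1.0 1.0
print(success_at_k(d2, 2), rr_at_k(d2, 2))   # 1.0 0.5
```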
Precision and recall are the classic accuracy-style metrics in this space. 00:02:47.960 |
The differentiator here from the previous ones is that 00:02:50.960 |
these are going to be sensitive to multiple stars. 00:02:53.620 |
So if you have more than one document that's relevant to your query, these metrics can register that. 00:03:00.440 |
So we have this notion of a return value that is just the set of documents k or above. 00:03:12.540 |
what percentage of the things at or above k are relevant? 00:03:17.260 |
And that's precision in the sense that if you picked k, 00:03:20.380 |
you're looking at the set of documents and you want to 00:03:22.540 |
know how many of them have stars relative to the total. 00:03:25.900 |
Or like the reverse of precision would be like which ones 00:03:29.340 |
are kind of imprecise as predictions because there's no star there. 00:03:33.060 |
And then recall is kind of the dual of that and it says for my chosen k, 00:03:37.920 |
how many of the stars made it up to k or above? 00:03:41.940 |
And the opposite of that would be like how many stars are lingering down below? 00:03:46.940 |
So you can see here because of the numerator that we're going to differentiate 00:03:50.740 |
systems now based on how many stars are at k or above. 00:04:02.660 |
For D2, it's 1 out of 2 because just half of them have a star. 00:04:11.820 |
Recall is very similar but now the denominator changes, right? 00:04:15.140 |
So the recall at 2 for this first one is 2 out of 3. 00:04:28.420 |
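Continuing the same hypothetical setup (again my own sketch, not lecture code), here are precision at k and recall at k; the printed numbers reproduce the values mentioned above.

```python
def precision_at_k(ranking, k):
    # Of the k documents at or above k, what fraction are relevant?
    return sum(ranking[:k]) / k

def recall_at_k(ranking, k):
    # Of all relevant documents, what fraction made it to rank k or above?
    total_relevant = sum(ranking)
    return sum(ranking[:k]) / total_relevant if total_relevant else 0.0

d1 = [1, 1, 0, 0, 0, 1]   # hypothetical rankings again
d2 = [0, 1, 0, 0, 1, 1]
print(precision_at_k(d2, 2))   # 0.5, half of the top 2 have a star
print(recall_at_k(d1, 2))      # 0.666..., 2 of D1's 3 stars are at 2 or above
```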
poor D3 has not fared well in our ranking so far. 00:04:40.860 |
Because now it's got all three of its stars at 5 or above. 00:04:46.860 |
even though they've got some high stars up there, 00:04:54.820 |
And that is maybe something that you want to watch out for because people kind of 00:04:58.580 |
innocently choose these k values when they're evaluating systems. 00:05:01.580 |
And I just showed you that that could really impact the ranking of systems. 00:05:09.180 |
you know, it's hard to imagine since there are only six documents. 00:05:12.060 |
But if it was a lot of work to travel down to our chosen k, 00:05:17.300 |
this would obscure the fact that we might pick as our winner 00:05:20.960 |
a system that had all the stars more or less at 1,000. 00:05:24.500 |
And the other systems which have their stars at the top of this ranking, 00:05:29.740 |
those might be diminished with such a high k. 00:05:33.320 |
And so that kind of gets you into the mode of thinking about actual users. 00:05:37.140 |
What is the cost of them scanning down a list of ranked results and things like that? 00:05:41.020 |
And that's where I want you to be when you think about these metrics. 00:05:44.660 |
What are you trying to solve out there in the world? 00:05:49.200 |
What is the cost of reviewing examples and so forth and so on? Yeah. 00:05:54.300 |
Will the neural IR models that we're going to talk about kind of solve this problem, right, 00:05:59.340 |
because right now everything's based on the presence or not of a word, 00:06:03.500 |
rather than kind of maybe a- either a longer meaning or, 00:06:07.540 |
um, like the quality of the relevance, however we define it. 00:06:11.140 |
Like maybe it only says the word once but actually has the best information afterwards. 00:06:17.420 |
Or is neural also going to be based on presence or not of words? 00:06:20.500 |
That's a great question. Ah, wait, we should be careful. 00:06:23.060 |
So yeah, I think for the first part of your question, 00:06:25.740 |
I want to say the neural IR models are overall going to be better. 00:06:34.140 |
It won't directly impact this because these stars after all aren't about terms. 00:06:39.100 |
This is about whether a whole document was relevant to a query. 00:06:41.980 |
You should imagine that the background process is like some team of humans went through and said, 00:06:46.900 |
okay, you searched for BERT and now I'm going through documents and saying which ones are relevant. 00:06:55.900 |
But I think you're right in your core intuition. 00:06:58.500 |
Term-based models are going to be kind of brittle. 00:07:01.260 |
And if we have hard query document pairs, they might miss them. 00:07:13.580 |
This is an example of why search is a hard NLU problem. 00:07:19.540 |
The query is, what compounds protect the digestive system against viruses, and there is a relevant document, 00:07:25.540 |
but there is zero relevant term overlap between query and document. 00:07:29.860 |
All of the connections that we want to make are deeply semantic connections. 00:07:34.620 |
And I do think that that is why neural IR models have pulled ahead for 00:07:40.100 |
accuracy-style assessments, though I'm trying to be careful with that claim, as you'll see. 00:07:45.300 |
I have one more metric which is average precision. 00:07:58.500 |
Average precision, notice it has no K. And the reason it has no K is that we're going 00:08:07.200 |
to consider the different Ks here where there is a relevant document. 00:08:14.200 |
Wherever there is a star, we're going to choose that as a K. And we're going to sum up 00:08:17.780 |
just those precision values and divide by the number of relevant documents. 00:08:23.800 |
Here's an example. Same three rankings that we had before, 00:08:27.700 |
and what I'll show you are these precision calculations. 00:08:35.380 |
And so we accumulate the precision values for one, two, and six. 00:08:52.900 |
which is reassuring, and we're also checking at every level. 00:08:55.780 |
So it's not going to have that sensitivity I showed you before where 00:08:59.500 |
the choice of K dramatically impacts which rankings we 00:09:02.980 |
favor because now we're kind of looking at all of the ones chosen by the ranking. 00:09:19.500 |
For D3, we do the same thing at positions three, four, 00:09:21.780 |
and five, and notice interestingly that D3 has pulled ahead of D2. 00:09:28.340 |
That's less surprising to me in the current context because D2 is kind of good and kind of not. 00:09:40.220 |
It has a star near the top, but the other two stars are way at the bottom of our ranking, 00:09:43.500 |
whereas at least D3 kind of put them all at least not literally at the bottom. 00:09:49.460 |
Whereas D1 looks like just a slam dunk winner here. 00:09:53.300 |
I mean, it simply has most of the stars right at the top. 00:10:01.540 |
So if I just stepped back from this little example, 00:10:04.620 |
I would say that average precision is kind of nice in terms of giving what looks to me like 00:10:09.700 |
a pretty nuanced picture of these three rankings. 00:10:13.500 |
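Again as an illustrative sketch rather than lecture code: average precision takes the precision at every rank that holds a relevant document and averages over the number of relevant documents.

```python
def average_precision(ranking):
    total_relevant = sum(ranking)
    hits, ap = 0, 0.0
    for rank, rel in enumerate(ranking, start=1):
        if rel:
            hits += 1
            ap += hits / rank          # precision at each rank that holds a star
    return ap / total_relevant if total_relevant else 0.0

d1 = [1, 1, 0, 0, 0, 1]                # hypothetical D1 with stars at 1, 2, and 6
print(average_precision(d1))           # (1/1 + 2/2 + 3/6) / 3 = 0.833...
```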
I think that's all of the accuracy style metrics. 00:10:19.260 |
Of course, there are others that you'll encounter. 00:10:21.100 |
Some are sensitive to the numerical like the float value, 00:10:24.620 |
because sometimes you have not just a one or a zero, 00:10:26.660 |
a star or not, but rather a float value for relevance. 00:10:29.860 |
There are lots of versions that of course average these over sets of queries, like mean average precision and mean reciprocal rank, 00:10:35.740 |
but underlyingly that's just some kind of arithmetic average of these scores. 00:10:42.000 |
Are there questions I can answer about these metrics? 00:10:54.980 |
What's it called? Like the discounted cumulative gain, 00:10:59.220 |
and it is the sum of all of the scores divided by something or other. 00:11:12.260 |
you know, have human labels and then you can take the precision, 00:11:17.060 |
or not precision, maybe just position-weighted combination of human labels. 00:11:29.220 |
And then you're just observing that very often for these datasets, 00:11:34.820 |
and then average precision is one way of aggregating over the labels we might have collected. 00:11:45.860 |
but like here's a partial list of things you could think about. 00:11:48.780 |
Which metric? Fundamentally, there is no single answer. 00:11:52.060 |
Is the cost of scrolling through k passages low? If so, success at k is probably fine, 00:11:58.540 |
because you don't care whether it was at position nine or position one; 00:12:02.060 |
what you really care about is that the user is kind of 00:12:04.660 |
confronted with a success that they can easily find. 00:12:10.780 |
Are there multiple relevant documents per query? 00:12:15.540 |
you probably shouldn't use success at k or RR at k, 00:12:19.580 |
because they're only sensitive really to one star. 00:12:22.620 |
And if you went to the trouble of getting multiple stars, 00:12:25.660 |
you know, why have your metric be insensitive to that? 00:12:28.580 |
So that seems clear. Is it more important to find every relevant document? 00:12:35.100 |
That would be a case where maybe human review is cheap, 00:12:38.460 |
or the cost of missing an example is hugely expensive. In that case, emphasize recall. 00:12:46.220 |
Conversely, if you just need to find some relevant things, 00:12:54.140 |
maybe because it's just good to know about them, then precision at k is probably the thing to emphasize, 00:12:57.820 |
because then all you really care about is that near the top are some relevant things. 00:13:03.740 |
F1 at k is the harmonic mean of precision and recall. 00:13:10.340 |
And that can be used where there are multiple relevant documents, 00:13:13.260 |
but maybe the relative order above k doesn't matter. 00:13:16.900 |
That's just one perspective on what I mean when I say we're combining precision and recall. 00:13:22.500 |
And then finally, average precision of the ones I've showed you, 00:13:26.140 |
will give you the most fine-grained distinctions of the metrics, right? 00:13:33.820 |
It's like precision because it aggregates over those precision values, and like recall because it looks at every relevant document. 00:13:38.660 |
So that looks like an awfully good way to really get a fine-grained ranking of systems. 00:13:45.020 |
And then finally, I'm going to talk about this a bit later. 00:13:52.260 |
This is a paper that I did with a team recently of researchers here and at IBM. 00:13:56.780 |
What we're seeing here is a kind of post-hoc leaderboard. 00:14:02.860 |
our complaint is that there are no leaderboards that 00:14:05.180 |
really do anything beyond measuring accuracy style things. 00:14:08.660 |
But if you did go through the literature as we did here and find a lot of systems, 00:14:13.900 |
you can see that they vary widely along other dimensions. 00:14:19.740 |
The MRR column, one of our rankings, goes from 19 to 37 or 39 or something. 00:14:24.740 |
So you say, okay, but then look just to the right of that at the query latency. 00:14:31.660 |
look how much time I have to spend versus down here where 36, 00:14:38.100 |
That is absolutely something that will matter to the search experience of users. 00:14:42.820 |
There is almost no way they're waiting around for 691 milliseconds, for example. 00:14:50.660 |
Right. If you care about space footprint and you will if you are indexing the web, 00:14:55.220 |
some of these have tiny little indices and then, uh-oh, 00:15:03.180 |
Right. So now if you need to hold it in memory, 00:15:11.660 |
So BM25, it has no hardware requirements at all. 00:15:17.880 |
Whereas these models down here that have these really high MRR scores, 00:15:30.700 |
Then of course, I hope you're thinking about this. 00:15:34.260 |
So then what is the best combination of all these things? 00:15:38.140 |
Well, it depends on how much money you have and how much time you have, 00:15:43.620 |
And so the best pitch I can make to you is that as you evaluate systems, 00:15:50.140 |
you think about what matters, and construct your evaluations on that basis. 00:15:54.800 |
That's gonna be a big theme of the course later on when you, 00:16:00.520 |
for your papers, are thinking about assessment. 00:16:04.480 |
I should have a whole section about my philosophy of assessment here and not 00:16:07.720 |
just fall into F1 or fall into success at K or whatever is relevant. 00:16:34.140 |
but vastly better in terms of their performance. 00:16:39.160 |
You know, this is like the Pareto frontier as they call it. 00:16:45.040 |
You would never choose any that are off the frontier, no matter what your values are. 00:16:48.520 |
And obviously, you can see that to favor this model, 00:16:51.720 |
there are gonna have to be other dimensions that we care about beyond cost and MRR, 00:16:56.040 |
because otherwise, that's just not a choice you would make. 00:17:01.920 |
there are hidden dimensions that need to be teased out that would 00:17:04.280 |
show that that ANS model is the best relative to. 00:17:15.280 |
Neural IR. First, we'll start with cross-encoders. 00:17:22.760 |
Okay. Here, just imagine I have a huge transformer. 00:17:29.640 |
I just concatenate the query and the document together, 00:17:44.000 |
Enormously powerful: to the comment from before, we get to model 00:17:54.600 |
every possible interaction between query and document. 00:18:05.740 |
I'm assuming that our dataset looks like this. 00:18:18.740 |
What I'm depicting on the left here is a model we could summarize like this. 00:18:28.980 |
and process them, and we retrieve this representation here, 00:18:36.020 |
We feed that through a dense layer that does our scoring, 00:18:39.700 |
and that is the basis for essentially a classifier. 00:18:44.060 |
This is called the negative log likelihood of the positive passage. 00:18:51.940 |
you will see that it is a typical classifier loss. 00:18:55.420 |
The only possible twist is that the numerator is 00:18:58.620 |
the positive passage score, and then in the denominator, 00:19:02.020 |
I have the positive passage score summed together with 00:19:04.380 |
all the negative passages that I have in my example set. 00:19:11.340 |
So that's why the examples look like this, because that's what's being used 00:19:16.980 |
here to optimize all these parameters to score documents. 00:19:22.780 |
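As a hedged sketch of what this loss looks like in code (shapes and names are my assumptions, not the course's implementation): the positive passage's score sits in the numerator, and the softmax denominator sums the positive with all sampled negatives.

```python
import torch
import torch.nn.functional as F

def nll_positive_passage(pos_score, neg_scores):
    # pos_score: (batch,) score for the one positive document
    # neg_scores: (batch, num_negatives) scores for the sampled negatives
    scores = torch.cat([pos_score.unsqueeze(1), neg_scores], dim=1)
    targets = torch.zeros(scores.size(0), dtype=torch.long)  # the positive sits at index 0
    # Cross-entropy with target 0 is exactly -log(exp(pos) / (exp(pos) + sum(exp(neg)))).
    return F.cross_entropy(scores, targets)
```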
Final thing, I hope you're thinking about this. 00:19:26.740 |
It's going to be incredibly expressive and powerful, 00:19:31.740 |
The cost of having the query and document interact at 00:19:34.940 |
query time is that I can't process any of these documents ahead of time. 00:19:41.860 |
Imagine you're Google and you're using a cross-encoder: the user queries. 00:19:45.820 |
You need to process that query together with every single document on the web, 00:19:54.660 |
But obviously, each query could take years to serve. 00:20:06.620 |
You see this sometimes where a cheap retriever gets a lot of documents, 00:20:10.220 |
like a 1,000 documents, and then the cross-encoder is used to re-rank just that 1,000. 00:20:18.540 |
Um, could you use this with multiple possible, 00:20:29.340 |
like multiple of those could be like good, but. 00:20:37.100 |
The numerator could be the sum of the positive and 00:20:39.340 |
then the denominator could just include all of those. 00:20:53.060 |
The other approach would be to just treat them as separate examples. 00:21:06.220 |
But it's worth thinking about. I'll get back to you on that. 00:21:27.700 |
but the point is, we process them separately. 00:21:30.280 |
And I've made lighter every state except the output tokens, 00:21:37.120 |
because those are the only ones that we need. 00:21:41.700 |
And then we do some kind of scoring on that basis like similarity. 00:21:48.660 |
Now, the similarity function, as I'm calling it, 00:21:52.140 |
for a query and a document is: we process the query, we process the document, and we grab those two output representations. 00:22:03.100 |
We just score based on those representations. 00:22:08.260 |
So now, we've got something that is highly scalable because we can process 00:22:13.740 |
every document in our entire web collection into a single vector, 00:22:24.000 |
and do this super fast dot product comparison for scoring. 00:22:28.620 |
So now we've got something that is probably even going to 00:22:31.380 |
function as a full ranking model, not just a re-ranker. 00:22:35.380 |
But the real game is that we can process all our documents offline. 00:22:41.500 |
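Here is a minimal sketch of that setup, assuming Hugging Face-style BERT encoders and using the final [CLS] state as the single vector; the encoder and variable names are placeholders, not a specific DPR release.

```python
import torch

def encode(encoder, tokenizer, texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Take the final-layer [CLS] representation as the vector for the whole text.
    return encoder(**batch).last_hidden_state[:, 0]

def scores_for_query(query_vec, doc_matrix):
    # doc_matrix can be computed offline for the whole collection; at query time
    # scoring is just one matrix-vector product of dot products.
    return doc_matrix @ query_vec

# doc_matrix = encode(doc_encoder, tokenizer, all_documents)          # offline
# scores = scores_for_query(encode(query_encoder, tokenizer, ["my query"])[0], doc_matrix)
```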
The cost is that there are almost no interactions between the query and the document. 00:22:46.920 |
if you think about the soft matching that happens with TF-IDF, 00:22:50.940 |
none of that is going to be able to happen here. 00:23:01.540 |
essentially what we'll be comparing would be two fixed length vectors, 00:23:05.020 |
right? I mean, a vector representing the document and another representing the query. 00:23:11.260 |
Is there a limit to the length of the document that would be represented in that vector? 00:23:15.580 |
I mean, like, could it represent an arbitrarily long document? 00:23:22.420 |
These are great questions. Let me repeat them. 00:23:26.700 |
The one constraint we need to impose on the query encoder and 00:23:31.420 |
the document encoder is that their outputs have the same dimensionality so that we can do the dot product. 00:23:34.160 |
They can otherwise be separate models if we want. 00:23:37.040 |
And then length of query and length of document, 00:23:39.840 |
that's just going to be imposed by whatever we 00:23:42.020 |
choose for the query and the document encoders themselves. 00:23:45.860 |
If you choose BERT, you're going to be stuck with 512 as the length, 00:23:47.980 |
the longest document that you can process unless we 00:23:50.620 |
do some further manipulation of these things. Yeah. 00:23:56.700 |
If these models are trained to project into kind of like a shared embedding space, 00:24:03.060 |
so like documents that are similar to a query are going to 00:24:06.040 |
fall into a similar location in embedding space, 00:24:09.700 |
could we have a system where we essentially take a new query, 00:24:17.840 |
project it into embedding space and then do like a nearest-neighbor search or something like that? 00:24:24.200 |
Well, yes. So some aspects of what you're describing, 00:24:30.000 |
The other parts of what you're saying are going to be 00:24:32.440 |
optimization tricks that I show you in a second, I believe. Yes. 00:24:37.960 |
Can you elaborate on what you mean by limited query-doc interactions? 00:24:41.880 |
Just that all we've got in the end is this vector for 00:24:45.860 |
the whole query and this vector for the whole document. 00:24:49.020 |
So token identities to the extent that they're preserved at all, 00:24:52.400 |
they have to have been packed into those vectors. 00:25:00.040 |
Compare that with the cross-encoder, where you get every token-level interaction you can imagine as a result of us using the transformer. 00:25:09.160 |
Is there any room for training some more clever synthesis of 00:25:13.120 |
the two representations you get at the end as opposed to just dot producting them? 00:25:17.760 |
I think that's a natural follow-on is that you might think, 00:25:20.660 |
I want to have in this layer here some additional parameters, 00:25:24.260 |
and you can kind of see how that might work, right? 00:25:28.260 |
I would put some parameters on top and then the same optimization can be used. 00:25:34.160 |
If they're going to the same embedding space, attention. 00:25:40.080 |
Yeah. Yeah. And this could be good in the sense that we, 00:25:44.280 |
we would pay a little bit of a cost by adding more parameters, 00:25:47.440 |
but we might gain something in terms of expressivity. 00:26:00.900 |
Notice that I've just showed you two loss functions, 00:26:06.000 |
and you can probably already see that they are identical 00:26:08.800 |
except for this function that you might call comp here. 00:26:14.100 |
And as you think about these different model architectures, 00:26:16.980 |
probably what you're thinking about is simply changing 00:26:19.860 |
this comp function and then using your available data to 00:26:23.200 |
train the model against this negative log likelihood of the positive passage. 00:26:27.720 |
There are other losses out there in the literature, 00:26:31.160 |
but this is the most widely used and it seems to be very effective. 00:26:35.760 |
Colbert. This stands for contextualized late interaction with BERT. 00:26:44.120 |
It was invented by Omar and Matei who are here. 00:26:46.760 |
Omar is my student, I work closely with Matei. 00:26:51.040 |
So Omar would want you to know that this stands for contextualized late interaction, 00:26:56.120 |
and he pronounces it Colbert because Stephen Colbert has 00:26:59.520 |
a show full of contextualized late night interactions. 00:27:07.340 |
It's your choice because the BERT there is the BERT model. 00:27:11.680 |
And we are, yes, still hoping that Stephen Colbert will take notice of this. 00:27:33.920 |
the query encoder on the side for reasons that you'll see, 00:27:37.520 |
So imagine BERT processes my query and I've grayed out everything, 00:27:41.000 |
but the final representations because crucially, 00:27:43.280 |
those are the only ones that we actually need. 00:27:49.300 |
Now what I'm going to do with Colbert is form a grid of scores. 00:27:54.040 |
And this is going to essentially give the similarity value 00:27:57.640 |
between every query token and every document token. 00:28:02.560 |
And then I will choose the maximum value along each row, that is, 00:28:08.480 |
for each query token, the document token that maximizes that similarity comparison. 00:28:12.880 |
And the scoring function is essentially the sum of those three max values. 00:28:18.800 |
That is why you see MaxSim all over the place for Colbert. 00:28:28.240 |
and I wrote down here what we would think of as the comp function, 00:28:40.760 |
So you can see why it's contextualized late interaction: like DPR, query and document are processed separately. 00:28:46.520 |
But unlike DPR, I'm allowing them all to interact with 00:28:50.360 |
each other via these very fast MaxSim calculations. 00:28:56.760 |
It's right, so highly scalable and highly expressive. 00:29:00.440 |
The only cost is that the interactions happen only in this very thin final layer. 00:29:10.120 |
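To make the MaxSim comp function concrete, here is a minimal sketch (my own, with shapes as assumptions): Q holds the query's output-token vectors, D holds the document's, and the score is the sum over query tokens of the best-matching document token.

```python
import torch

def maxsim_score(Q, D):
    # Q: (num_query_tokens, dim), D: (num_doc_tokens, dim), typically L2-normalized.
    sim = Q @ D.T                            # grid of query-token x doc-token similarities
    per_query_max = sim.max(dim=1).values    # best document token for each query token
    return per_query_max.sum()               # ColBERT-style MaxSim score
```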
One nice thing is that Colbert, because of this, brings us back to common intuitions in IR. 00:29:14.960 |
We do genuinely achieve with these MaxSim scores intuitive soft alignments. 00:29:20.120 |
Here I have the query, when did the Transformers cartoon series come out? 00:29:25.680 |
and a document, the animated Transformers was released in August 1986. 00:29:29.480 |
And these are proportional to actual MaxSim values, query relative to document. 00:29:37.160 |
And you can see that it is doing something very intuitive, 00:29:44.240 |
I don't have to do anything special to capture the fact that 00:29:47.600 |
come in the context of come out is a lot like released. 00:29:55.000 |
Here I'm showing the two topmost MaxSim values, 00:30:00.040 |
And this is wonderful because IR has been so successful for so long, 00:30:04.200 |
doing term matching, and it is nice to see that intuition 00:30:07.920 |
carried forward into this more semantic space in my view. 00:30:14.880 |
Is that matrix of like query to document mapping, 00:30:19.760 |
is that why the in-memory index is very purple there? 00:30:27.840 |
It is because we have to store every token level representation. 00:30:32.080 |
Yes. I'm gonna, I'm gonna show you that we can do better, 00:30:36.960 |
our entire document store is gonna be a lot of vectors. 00:30:40.600 |
One per token, not one per type, one per token. 00:30:53.240 |
Maybe there's some intuition that can make this a little clearer. 00:31:04.360 |
they, they're all related to the original token transformers. 00:31:07.840 |
Would you be able to kind of draw that link to all those other tokens as well, 00:31:23.960 |
It's amusingly ambiguous between our model and the animated cartoon series. 00:31:32.240 |
So, but I think the point is that because it's BERT, 00:31:39.320 |
the cartoon sense of transformers here gets a very different representation from the one it would get if we were talking about NLP models. 00:31:43.200 |
And that's why it's so good that we are using BERT because then we'll 00:31:46.280 |
get MaxSims that are appropriately semantic. That is the hope. 00:31:50.600 |
Yeah. Whereas term-based models are really gonna struggle, 00:31:55.840 |
unless you have n-grams that kind of capture the fact that this transformers is the cartoon-series one. 00:32:01.840 |
So another, actually, that's another argument in favor of being in a more semantic space. 00:32:08.760 |
I want to just quickly talk with you about how we have worked to optimize Colbert, 00:32:15.880 |
because I think that this suggests things that you would want to do if you 00:32:18.600 |
developed your own neural retrieval models because the hard truth here is that 00:32:23.080 |
BM25 is blazingly fast and scalable and these neural models are not. 00:32:29.800 |
You have to work much harder to get them to that point of being as performant 00:32:33.240 |
in terms of other dimensions beyond accuracy. 00:32:36.660 |
We could use Colbert as a re-ranker as I alluded to before, right? 00:32:40.960 |
So here I have all these token level representations which I do have to store, 00:32:55.920 |
For a query that's represented as a sequence of tokens, uh, 00:32:59.080 |
I could get the top k documents for it using like BM25 and then re-rank that top k. 00:33:08.560 |
I pay the full price of the Colbert model but only for k documents. 00:33:12.840 |
And you're hoping that BM25 did a good job of getting you to that initial point. 00:33:18.240 |
It's a very common application and it can be really meaningful to 00:33:29.960 |
We again store all those token level vectors, 00:33:33.020 |
but now we're gonna kind of turn things around. 00:33:35.160 |
We just need to keep track of those vectors and their associated documents. 00:33:39.820 |
For a query that we have encoded as a set of vectors using Colbert, 00:33:47.940 |
and we retrieve the p token vectors from 00:33:50.960 |
this huge list that are most similar to our target. 00:33:54.360 |
And that doesn't require the full Colbert model. 00:33:56.840 |
That could just be a similarity calculation and you can do those really fast. 00:34:03.200 |
And then you get all the documents associated 00:34:06.120 |
with this small set of vectors that you found and you score them. 00:34:09.640 |
So again, the name of the game is to use Colbert only very 00:34:13.680 |
sparingly in a final stage because it is so expensive. 00:34:21.240 |
just quickly, we can do even better and this is quite striking. 00:34:24.860 |
What we can do is cluster our token level representations 00:34:28.940 |
into their centroids using k-means clustering. 00:34:36.860 |
So again, we encode our query into a series of vectors. 00:34:43.060 |
For each query vector, we first get the centroids that are closest to it. 00:34:43.060 |
It turns out we can collect only like four centroids per token vector and do really well. 00:34:48.620 |
Then we get the t most similar token vectors to that centroid. 00:35:01.580 |
And then we finally do scoring on the associated documents. 00:35:08.980 |
we have reduced the amount of compute we need to do with this huge index 00:35:13.300 |
by using the centroids and then using Colbert again very sparingly. 00:35:18.660 |
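Here is a rough sketch of that candidate-generation idea, not the actual ColBERTv2/PLAID code; the index structures (centroids from k-means, an inverted list from centroid to token ids, and a token-to-document map) are assumptions for illustration.

```python
import torch

def candidate_documents(query_vecs, centroids, centroid_to_token_ids, token_to_doc, n_probe=4):
    # query_vecs: (num_query_tokens, dim); centroids: (num_centroids, dim)
    sims = query_vecs @ centroids.T                      # query tokens vs. centroids
    top_centroids = sims.topk(n_probe, dim=1).indices    # a few centroids per query token
    candidates = set()
    for row in top_centroids.tolist():
        for centroid_id in row:
            for token_id in centroid_to_token_ids[centroid_id]:
                candidates.add(token_to_doc[token_id])
    return candidates   # only these documents get the full MaxSim scoring
```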
Final thing and then I'll take some questions. 00:35:22.020 |
The team has worked very hard to reduce the latency of Colbert. 00:35:27.900 |
And the thing I want to point out to you is that the Colbert model steps, 00:35:32.640 |
the one I just described to you with the centroids, 00:35:37.260 |
were already a small part of the overall cost because it was being used so sparingly. 00:35:40.380 |
The big costs were in dealing with 00:35:44.460 |
the huge index and also doing work to quantize the vectors, 00:35:49.940 |
so that they were easier to store on disk by making them smaller. 00:35:53.520 |
And so after a bunch of work with this framework called Plaid, 00:35:57.580 |
they were able to get rid of almost all of that index lookup and 00:36:01.380 |
de-quantization or decompression steps for the vectors that were costing so much. 00:36:05.900 |
And they brought the latency down to like 58 milliseconds. 00:36:10.620 |
Which- so it went from something that is impossible to imagine deploying 00:36:17.820 |
to something you might entertain as a possibility for deployment. 00:36:21.100 |
And I- you know, the details are in the Plaid paper. 00:36:25.500 |
I just wanted to call out that I think this is an incredible achievement. 00:36:30.980 |
the set of things that they did to achieve this enormous improvement. 00:36:36.580 |
And it does mean that if you had heard a rumor that Colbert was 00:36:40.580 |
impractical to use because the index was too large and the latency was too long, that is no longer true. 00:36:46.500 |
The indices are small because of quantization, and the latency is now quite low. 00:36:54.100 |
I have one more model, but let me take questions. 00:36:58.380 |
Oh, sorry. I just had a question about the latency and also the predictor. 00:37:03.860 |
The Plaid paper is full of tricks and things like that. 00:37:08.580 |
I definitely want to give Sid plenty of time to talk about models. 00:37:18.740 |
The last model is SPLADE. I wrote sequence at the bottom because we're going to do this for both queries and documents. 00:37:32.980 |
Okay. So we again process the text into the output states, 00:37:44.020 |
And the scores are determined by this thing here, that's s sub ij. 00:37:48.420 |
So we're going to apply a linear layer to the encoding, 00:37:51.900 |
those output states, and we're going to combine it with 00:37:55.500 |
the embedding for these vocabulary items with a bias. 00:38:01.060 |
you can see that this is like a dot product of 00:38:03.540 |
these states with all of these values here in our vocabulary. 00:38:12.060 |
And so you can think of that as summing across all the document tokens. 00:38:15.260 |
And so what we've got in that orange column there is a probably very sparse vector 00:38:22.140 |
that represents this text down here with respect to our vocabulary. 00:38:29.260 |
So this is a lot like term-based, uh, work of old, right? 00:38:43.420 |
So we should get the advantages of working with a semantic model. 00:38:48.180 |
And then the similarity value is just the SPLADE representation, 00:38:52.580 |
that is this representation here for the query dot product with the document. 00:38:57.820 |
And the loss is the one that we've been using all along. 00:39:02.100 |
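As a hedged sketch of the idea (not the official SPLADE code): the "linear layer combined with the vocabulary embeddings and a bias" is what a BERT masked-language-modeling head computes, so one way to write it is with a Hugging Face-style MLM model.

```python
import torch

def splade_vector(mlm_model, tokenizer, text):
    batch = tokenizer(text, return_tensors="pt")
    logits = mlm_model(**batch).logits                 # (1, seq_len, vocab_size)
    # log(1 + ReLU(.)) keeps scores non-negative and sparse-ish; summing over the
    # sequence gives one vocabulary-sized vector for the whole text.
    return torch.log1p(torch.relu(logits)).sum(dim=1).squeeze(0)

def splade_score(mlm_model, tokenizer, query, document):
    # Documents can be encoded offline; the similarity is just a (mostly sparse) dot product.
    return splade_vector(mlm_model, tokenizer, query) @ splade_vector(mlm_model, tokenizer, document)
```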
So just to be clear, so you do the SPLADE process both with the query and with the document. 00:39:12.940 |
Yeah, that's it. There's a bunch of- it looks similar in my document. 00:39:15.980 |
This is great. Let me review what you just said. 00:39:18.860 |
Sequence, not query or document because we do this for both kinds. 00:39:22.900 |
And of course, we can do all the documents ahead of time. 00:39:26.180 |
The big twist is that we're scoring these sequences with respect to the vocabulary. 00:39:31.140 |
And we are essentially getting in semantic space because this is an embedding space here, 00:39:36.020 |
and this is a contextual embedding space here, 00:39:38.500 |
scores for each query term with respect to the whole vocabulary. 00:39:47.980 |
and their optimization further encourages sparsity. 00:39:51.300 |
And then the similarity value is the dot product of those for queries and for documents. 00:39:56.580 |
So it has some hallmarks of late interaction, 00:39:59.860 |
except it is interacting the text representations with the vocabulary rather than directly with each other. 00:40:12.020 |
You saw it in some of my slides before; it is very impressive. 00:40:16.860 |
And it's also a new way of thinking, which I really like. 00:40:24.780 |
There are many more models in this space, and one theme of them, I won't go through them, 00:40:26.620 |
is just that people are working hard finally on making these models more efficient. 00:40:31.780 |
So a big theme of this is not just obsession with accuracy, but also efficiency. 00:40:39.620 |
And then finally, for that paper that I mentioned before, 00:40:43.300 |
we did a bunch of systematic investigations of different approaches. 00:40:54.580 |
these models where people have worked hard to optimize them. 00:40:57.580 |
There's lots of tables like this in the paper. 00:41:02.080 |
BM25 is the only solution that could run on this tiny hardware here. 00:41:09.900 |
That's why it's alone in its own little block there. 00:41:30.900 |
except one of them is double the latency of the other one for this hardware. 00:41:36.900 |
And so you might wonder, do I really need this extra point of performance? 00:41:50.180 |
but its latency is a quarter or something of the Colbert v2 small. 00:41:57.780 |
So maybe you care more about that and not so much about the success. 00:42:01.380 |
And then if you compare these two SPLADE models, right, 00:42:05.540 |
But if you just jack up the hardware a little bit, 00:42:16.660 |
It went up for all of them with this heavy duty hardware. 00:42:22.500 |
So this is the space that you're actually operating in. 00:42:25.020 |
I'll- we'll talk later about how we might more systematically integrate all these scores. 00:42:29.940 |
I think this is enough now to get you thinking about all of these dimensions. 00:42:40.020 |
we show that you never need a GPU for Colbert, I believe. 00:42:42.900 |
You just- so you can always use cheaper hardware. 00:42:52.180 |
The final section of this is just some datasets. 00:42:54.220 |
I think I don't need to go through it because you have it as a resource. 00:42:56.660 |
If you want to get started in neural information retrieval, 00:43:02.140 |
and then there are a bunch of new benchmarks that are 00:43:04.420 |
designed to assess systems out of the box, that is zero-shot. 00:43:07.540 |
BEIR is great. LoTTE is great for long-tailed, 00:43:11.260 |
topic-stratified evaluation, and then this XOR-TyDi is cool because it is multilingual. 00:43:17.140 |
And I know you have expressed interest in multilingual stuff. 00:43:20.060 |
This could be a great playground for doing that with kind of QA and retrieval, 00:43:35.620 |
NLU and IR are back together again after being apart for so long, 00:43:40.340 |
and this is having profound implications for research and technology development. 00:43:44.860 |
So this is absolutely a very exciting moment to participate in this research because there is 00:43:50.420 |
so much innovation yet to happen and it is having 00:43:54.060 |
such an impact on research and also out in the wider world. 00:44:03.260 |
So it's cool. It's like retrieval isn't just hitting NLU, 00:44:06.380 |
it's hitting everywhere, like vision and robotics. As of like 00:44:10.260 |
this week, we're starting to use retrieval methods to do new tasks. 00:44:13.660 |
What's the best way to figure out how to do a new task? 00:44:16.820 |
Maybe retrieve some examples of a robot or a human doing 00:44:20.460 |
the same task and then generate your actions from those. 00:44:30.580 |
Yeah. All right. So I'm going to kind of pick up or try to pick up where I left off 00:44:39.060 |
last week and kind of give you this evolution, 00:44:41.820 |
this history lesson on how we got to the transformer, 00:44:45.420 |
and then go from there into tips and tricks for training big models generally, 00:44:50.620 |
and then end with like a small little teaser on 00:44:56.060 |
So you can use that in your projects down the road. 00:45:07.740 |
We looked at the state of things when the transformer paper came out, on both the RNN and the CNN side, 00:45:11.340 |
and tied a lot of the innovation around the transformer 00:45:22.300 |
to the CNN side of things, and the connections there were closer than the connections to RNNs. 00:45:26.300 |
Kind of walk through how we got to the self-attention block with this fancy code, 00:45:33.580 |
splitting your heads and you can kind of think of your heads in 00:45:36.300 |
a self-attention block as the different kind of kernels or filters in a CNN layer. 00:45:41.460 |
Then kind of closing with like this full self-attention block, 00:45:45.860 |
where we're actually doing the RNN style attention, 00:45:48.500 |
and then this question of this non-linearity that we're adding at the end. 00:45:54.780 |
because without this sort of MLP that we're adding to the end of each transformer block, 00:45:58.180 |
we're really just doing weighted averages of linear transforms of values. 00:46:02.820 |
Okay. So, if we kind of take this as ground truth, 00:46:09.060 |
starting point for what a transformer block looks like, 00:46:12.140 |
very much inspired by the ideas of CNNs and RNNs with attention at the time. 00:46:23.380 |
again as we stack more and more layers together. 00:46:25.860 |
There's a problem. Can anyone spot the problem in this implementation by itself? 00:46:34.700 |
We keep adding the same input over and over again as we go deeper. 00:46:38.980 |
Eventually, specifically in the RNN attention layer, 00:46:43.700 |
when we take this dot product between the queries and the keys, the magnitudes can blow up. 00:46:53.340 |
kind of building the transformer layer is very, 00:46:56.680 |
the second part is just trying to make sure it doesn't fail, 00:46:59.620 |
and doesn't blow up, and doesn't crash when we try training it. 00:47:02.300 |
So, what's one thing that we can do to kind of avoid this blow-up? 00:47:11.260 |
So, layer normalization, maybe batch norm and layer norm were covered earlier on, 00:47:19.540 |
we're just going to normalize so that each feature has mean zero, 00:47:23.340 |
standard deviation one, which means that every time we add a residual connection, 00:47:28.300 |
we're going to normalize so that everything comes back to a decent space. 00:47:31.300 |
We're still able to learn the kind of same level of expressivity we care about. 00:47:37.900 |
but we don't keep blowing up or growing the magnitude of our activations. 00:47:45.860 |
In code, that's nn.LayerNorm with the dimensionality of our transformer. 00:47:51.940 |
We're just going to normalize each X before we pass it 00:47:54.300 |
into the attention and the MLP layers respectively. 00:47:57.420 |
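A minimal sketch of that pattern (my own illustration, with attn and mlp as placeholder submodules): normalize the input before each sublayer and keep the residual connection around it.

```python
import torch.nn as nn

class NormedBlock(nn.Module):
    def __init__(self, d_model, attn, mlp):
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.ln_attn = nn.LayerNorm(d_model)
        self.ln_mlp = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln_attn(x))   # normalize x before the attention sublayer
        x = x + self.mlp(self.ln_mlp(x))     # normalize x before the MLP sublayer
        return x                             # residual stream stays in a reasonable range
```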
Now, there's a problem with this that isn't obvious, 00:48:00.860 |
and actually wasn't obvious to the people building transformers at the time. 00:48:03.580 |
It wasn't really explained kind of till three years later, 00:48:06.900 |
which is that you have optimization issues when you do this. 00:48:14.300 |
If you train the naive transformer with this layer norm in place, 00:48:19.500 |
with a standard schedule, which is like learning rate decay or a constant learning rate, training just doesn't go well. 00:48:25.740 |
Specifically, I'm going to use the hugging face emojis, 00:48:38.540 |
they would tell you one or the other is happening, 00:48:40.740 |
but there are definitely no gradients that are like 00:49:00.540 |
Now, this is actually just like fun because I have the time. 00:49:07.980 |
>> I'm thinking, like I think I remember in the [inaudible] paper, 00:49:21.100 |
they had like a weird learning rate, but I don't remember. 00:49:23.380 |
>> So, it is in the original transformer paper, 00:49:26.300 |
like the main thing that they do to get this stable. 00:49:28.980 |
So, it's one of the authors that came up with it. 00:49:33.580 |
If you look on the first ever transformer code base from Google, 00:49:39.100 |
in like the argparse flags for the different optimizers you can use, 00:49:44.340 |
there's one option just called Noam, after Noam Shazeer. 00:49:54.700 |
And it was called the Noam optimizer for a really long time, 00:49:57.620 |
until they just decided to call it just like, 00:50:02.460 |
And so, Noam Shazeer kind of came up with it. 00:50:05.420 |
And if you were to kind of go back and think about like 00:50:08.340 |
the sorts of problems and the papers he was working on at the time, 00:50:13.820 |
the different types of like gradient descent optimizers, 00:50:19.660 |
like RMSProp, Adafactor came out like a year after the transformer paper came out. 00:50:25.060 |
And he was really like interested in like looking at this problem of like, 00:50:28.060 |
"Huh, okay, weights seem to be where like if you just 00:50:30.900 |
really inspect the gradients early on with like this layer norm thing, 00:50:34.340 |
variance seems to be high and you kind of want to burn that in." 00:50:38.900 |
so he's kind of doing this already for his LSTM work. 00:50:48.620 |
It breaks conventional machine learning wisdom, 00:50:50.940 |
like why am I warming up my learning rate before I'm bringing it down, right? 00:50:57.660 |
Usually you would start it high, move in, and then like maybe anneal it as I get closer to my minimum. 00:50:57.660 |
Nobody really had a good explanation till three years later, when a paper comes out that kind of like steps through 00:51:09.580 |
the specifics of training a transformer model on some data, 00:51:16.260 |
and actually ties it to the layer normalization layers that we just added. 00:51:21.860 |
Right. So up top, we have kind of good gradients. 00:51:25.180 |
Right. So on the left here is the gradient magnitude and here's the update magnitude. 00:51:30.820 |
the updates that are actually applied to the weights. 00:51:34.220 |
in red, we have the same thing but without warm up. 00:51:37.420 |
And what ends up happening is that gradients go to zero somehow as you train. 00:51:41.820 |
It's actually a weird graph because like as you're coming forward in time, 00:51:47.180 |
So this is kind of like starting out and then like this is kind of, yeah. 00:51:54.420 |
you know, wherever training becomes unstable. 00:51:57.380 |
And then your updates also become super high variance. 00:52:00.860 |
So they do some math and they kind of bound the update as a kind of 00:52:07.020 |
like proportional to the dimensionality of the, 00:52:11.540 |
or the square root of the dimensionality of your transformer, 00:52:17.460 |
like if the size of your activation is like sufficiently large, 00:52:21.940 |
your layer norm gradient is going to be completely, completely screwed. 00:52:31.780 |
so warm up is necessary because it helps you get 00:52:35.780 |
the Adam optimizer to kind of like move slowly enough at the beginning. 00:52:40.900 |
So that we're kind of like saturating the gradients, we're like, okay. 00:52:48.780 |
The activations norms aren't changing all too much. 00:52:53.740 |
They're changing in a predictable way and we can kind of start to handle that, 00:52:56.540 |
and then conventional ML kicks in. But it's weird. 00:53:00.340 |
And it's also weird that it took three years later, 00:53:02.340 |
and some people still don't buy this explanation, 00:53:04.340 |
but it's the best explanation I've got to why we need that warm up. 00:53:09.700 |
So the rule of thumb is, if you're fine-tuning or pre-training a transformer, 00:53:11.980 |
warm up for at least five percent of your full training run. 00:53:20.660 |
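A minimal sketch of that rule of thumb (the decay shape is my assumption; the lecture only specifies the roughly-5% warmup): linear warmup to the peak learning rate, then a taper back down.

```python
import torch

def warmup_then_linear_decay(optimizer, total_steps, warmup_frac=0.05):
    warmup_steps = max(1, int(total_steps * warmup_frac))

    def lr_scale(step):
        if step < warmup_steps:
            return step / warmup_steps                                        # ramp up to the peak LR
        return max(0.0, (total_steps - step) / (total_steps - warmup_steps))  # then taper toward zero

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
```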
So is there some data dependency or some assumption about what the data will be like? 00:53:40.700 |
>> Yeah. So I think in this paper they're looking at what I'll call nice datasets. 00:53:45.300 |
They're looking at the Wikitext 2s of the world that are somewhat predictable. 00:53:59.060 |
Just really, really unpredictable things that are low likelihood under your model, 00:54:03.660 |
that are going to cause big updates in the middle of 00:54:05.100 |
training that are going to completely crash your run. 00:54:07.380 |
So this happened tons of times while we were training 00:54:13.900 |
any big model in like the one billion plus parameter range. 00:54:24.020 |
If you look at the notes in like the GitHub repository for like training a T5, 00:54:31.500 |
they basically say: rewind to an earlier checkpoint, re-randomize your data order and then try again, 00:54:38.420 |
and that is kind of how most modern language models are trained right now. 00:54:43.740 |
We think it's tied to the data, we can't isolate it. 00:54:49.340 |
Eventually, you'll just keep making progress. 00:55:03.660 |
So up top is blue with the traditional transformer learning rate, 00:55:14.420 |
and in red, let's just start the learning rate high and then taper. 00:55:26.660 |
>> Yeah. So the way you can interpret this graph, 00:55:29.020 |
and the paper's linked in at the bottom of the slide, 00:55:31.980 |
but you can think of the furthest back magnitude is like this, 00:55:36.820 |
basically plotting the mean standard deviation of the updates across layers. 00:55:47.300 |
you get to batch 100 or batch 400 or whatever. 00:55:54.580 |
>> I wonder to what extent the warm up and the rate. 00:56:00.060 |
>> I think I'm relating to the choice of optimizer. 00:56:05.540 |
Because I've run into some problems where I found that using 00:56:11.060 |
Adamax with the infinity norm will work because I've got this level of dropout or whatever. 00:56:15.740 |
Is there any guidance on all these big hyperparameters that go into this tuning, 00:56:23.580 |
I should be pushing one down or choose this optimizer, 00:56:31.580 |
it feels a little bit like the Wild West, which is what it is. 00:56:34.820 |
>> Yeah. So if I were to paraphrase your question, 00:56:38.540 |
it's like, if you decide to change anything about the current recipe, 00:56:42.380 |
like change your optimizer, change dropout, change your learning rate, 00:56:45.220 |
are there rules of thumb for what else you need to change to get things to work? 00:56:54.820 |
So part of why I led with how we got to here, 00:57:05.620 |
Because it's still at the point where the know-how for optimizing these models 00:57:10.660 |
is still concentrated in the minds and experience of a very small number of people. 00:57:16.820 |
Because who's trained a seven billion parameter language model, 00:57:19.460 |
or who's trained a 100 billion parameter language model? 00:57:23.340 |
When you're talking about a training run that cost millions of dollars to develop, 00:57:31.340 |
how many things are you really going to be trying at the end? 00:57:38.740 |
There is like scaling laws research where they're trying 00:57:40.500 |
these different things in some bounded search space. 00:57:43.420 |
But if you were to invent like a brand new optimizer, 00:57:47.780 |
that, say, tracks third, fourth order moments of your gradients, 00:57:51.300 |
maybe do something fancy relative to how things are changing over time. 00:57:55.860 |
And you were to just try and apply it to the biggest language model you could train, 00:58:00.740 |
I have no idea what I would tell you in terms of like what things you should change, 00:58:06.220 |
like maybe don't set the learning rate to be ridiculously high. 00:58:16.380 |
let's just say during training you come across a bad batch of data that happens to 00:58:20.300 |
cause like the destabilization with the gradients. 00:58:23.700 |
And then you rewind back to your checkpoint and you take that same exact batch of data, 00:58:28.020 |
but instead of running it through training when gradients are enabled, 00:58:37.500 |
you just run inference on it, would you see anomalous behavior from the model during inference, 00:58:41.060 |
or is it strictly just a back propagation issue? 00:58:47.140 |
so we a couple of years ago trained like some, 00:58:59.140 |
The way we debugged it was like looking at the gradients, 00:59:02.140 |
which didn't tell us much, but then we just looked at activation norms per layer, 00:59:06.140 |
and that's how we actually debug this, right? 00:59:12.420 |
like where we thought we could possibly be overflowing or underflowing, 00:59:23.540 |
Because a single batch isn't going to perturb everything. 00:59:29.980 |
eventually you're going to fall under some trajectory where things get bad, 00:59:34.900 |
So we would be able to deterministically be running to that checkpoint, 00:59:38.060 |
then deterministically step through training and log every activation, 00:59:42.540 |
which is expensive, but that's how we were able to get to the bottom of the problem. 00:59:46.060 |
But I don't think we have tools for actually figuring out which batch of data was, 00:59:51.100 |
or which sequence of batches of data were the actual triggers for that behavior. 00:59:55.860 |
>> I guess I was just curious specifically about if the same data that causes 01:00:00.020 |
destabilization in training can cause anomalous behavior, 01:00:09.380 |
There's some more recent work about how to quantize these models, 01:00:12.860 |
like how to get a transformer that's trained with like 16-bit precision to 01:00:17.060 |
like train or to run with like eight bit precision by like intelligently like 01:00:21.740 |
bucketing floats from Tim Detmers who's a PhD student up at UW. 01:00:27.140 |
He has this theory on something called outlier features that show up in 01:00:30.860 |
these really big models that kind of try and get at this, 01:00:36.980 |
Yeah. Okay. So are we done now that we've fixed this layer norm stuff; is that 01:00:45.660 |
all of the stuff to get the transformer to work? 01:00:47.660 |
Kind of. Right. So like over the last few years, 01:00:51.420 |
people have wanted to like milk the most out of the transformer, so there are a few more tweaks. 01:00:59.900 |
So one, when you're training, and potentially when you're projecting to queries and 01:01:07.220 |
keys and values, the bias term in each linear layer, like Wx plus b, 01:01:10.580 |
you can get rid of the Bs because like it's not really doing 01:01:12.540 |
anything and it saves a little bit of compute. 01:01:14.460 |
So let's get rid of them. Like that's like the first thing to throw out. 01:01:17.700 |
There are different activations that have been invented and different types 01:01:21.020 |
of like cool ways to just better fit your data. 01:01:29.820 |
A GLU, a gated linear unit, defines like a separate weight matrix as part of the activation. 01:01:33.940 |
And then a swish is like a sigmoid activation 01:01:39.740 |
that applies to one part of the weight and not the other. 01:01:44.460 |
This is actually the activation of choice now in most transformer implementations. 01:01:51.020 |
Palm was trained with this, works really well. 01:01:55.060 |
One thing that folks noticed is that moving the layer norm to happen, 01:02:00.860 |
you know, before you actually feed it through the attention or the MLP layers instead of 01:02:07.340 |
after, is a more stabilizing force and is actually kind of important. 01:02:16.580 |
you don't actually need these trainable parameters for mean and variance. 01:02:20.980 |
You can actually just like divide by the root mean square, 01:02:23.500 |
or the RMS, of like your entire activation vector. 01:02:28.940 |
irrelevant flops because we're training massive models and we're 01:02:31.700 |
trying to do the bigger model on the compute that we have. 01:02:35.900 |
Oh, yeah, here's the code for like a SwiGLU activation and an RMS norm. 01:02:44.380 |
a projection layer is basically saying like let's take 01:02:46.580 |
this input feature projected into like two separate chunks. 01:02:50.260 |
One chunk becomes like a gating value kind of like 01:02:53.260 |
in a gated recurrent unit in like the RNN literature. 01:02:59.540 |
you apply the sigmoid to the gate and then multiply it element-wise with 01:03:03.100 |
the value and you get your new thing, works really well. 01:03:06.620 |
An RMS norm is like literally just dividing by the norm 01:03:10.180 |
of the vector instead of like trying to learn anything fancy. 01:03:13.940 |
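Since the slide code isn't reproduced in the transcript, here is a rough reconstruction of a SwiGLU feed-forward layer and an RMS norm (hidden sizes and details are illustrative, not the exact slide):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.proj = nn.Linear(d_model, 2 * d_hidden, bias=False)  # no bias, project into two chunks
        self.out = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return self.out(F.silu(gate) * value)   # swish on the gate, element-wise with the value

class RMSNorm(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Just divide by the root mean square of the activation vector; no trainable
        # mean/variance parameters (some variants do add a learned scale).
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms
```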
Cool. This is what the modern transform looks like. 01:03:17.820 |
So that's it for the evolution of the transformer. 01:03:23.420 |
nothing in the last two weeks has like changed drastically from this. 01:03:29.900 |
let's say we are doing a fine-tune instead of like the full train, 01:03:36.660 |
Would we still want to follow these kinds of guidelines, or are these specific to 01:03:41.220 |
just doing a full pre-train on all of the data? 01:03:46.860 |
what of this do we really need if we're fine-tuning, 01:03:50.420 |
or if we're kind of doing like parameter efficient fine-tuning, 01:03:52.460 |
like is this only necessary for pre-training? 01:03:54.380 |
So I've started using the SwiGLU instead of like any other activation like everywhere, 01:04:07.340 |
Everything else you can probably not care about. 01:04:10.540 |
The RMS norm, the pre-norm is probably just like a general rule of thumb if you're adding 01:04:16.180 |
any transformer layers just because it is like demonstrably more stable, 01:04:22.900 |
Other questions here before we move to how to train on lots and lots of compute. 01:04:44.740 |
but I have seen a few things as far as language models have gone. 01:04:50.340 |
So like 2018 is when I think I did my first deep learning tutorial. 01:04:55.940 |
I trained an MNIST classifier, like the typical two-layer network. 01:05:04.940 |
It's the 100,000 parameter line. That's 2018. 01:05:21.460 |
2020, I kind of branch out from like the small NLP stuff I'm doing to like 01:05:26.100 |
more intensive NLP so looking at tasks like summarization, 01:05:30.420 |
training models with like 10 million parameters. 01:05:33.580 |
Then by 2021, like the biggest models I was training, 01:05:39.580 |
was like when I switched into multimodality robotics, 01:05:41.660 |
looking at visual question answering, 18 million parameters. 01:05:47.660 |
and I think this was the standard pipeline for a lot of people, 01:05:52.020 |
is like I'd be able to train most of my things on 01:05:55.140 |
one GPU or even my laptop CPU for like a maximum of a few hours. 01:06:03.620 |
at least for like most of the things I was doing on a day-to-day. 01:06:15.260 |
Let's at least figure out if we can get an academic lab to 01:06:18.140 |
like try and train a GPT-2 like the earlier generation." 01:06:21.420 |
So clocking in at like 124 million parameters, 01:06:26.700 |
an order of magnitude bigger than anything I had trained at the time. 01:06:28.860 |
So why I decided to do this is still beyond me, 01:06:33.300 |
One of the useful things that I learned is that training 01:06:40.940 |
this model on a decent GPU that we had access to at the time, 01:06:44.300 |
with a batch size greater than four would go out of memory, 01:06:47.420 |
which is bad because a batch size of four is small, 01:06:50.340 |
and we wanted to ideally train with like a batch size of 512. 01:06:57.980 |
So we used gradient accumulation, which is like I'm going to run batch sizes of 01:07:00.140 |
four however many times it takes to get into 512, 01:07:05.740 |
processing all of those batches sequentially. 01:07:07.860 |
So I'm just going to keep accumulating the gradients, 01:07:16.860 |
And if you do the math, it's like 100 days to train on that single GPU for 400,000 steps. 01:07:20.900 |
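A minimal sketch of gradient accumulation as described (model, loader, and loss function are placeholders): micro-batches of 4 accumulated 128 times give an effective batch of 512.

```python
def train_with_grad_accum(model, loader, optimizer, loss_fn, accum_steps=128):
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        loss = loss_fn(model(x), y) / accum_steps   # scale so the sum matches one big batch
        loss.backward()                             # gradients accumulate across micro-batches
        if (i + 1) % accum_steps == 0:
            optimizer.step()                        # one optimizer step per effective batch
            optimizer.zero_grad()
```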
So how do we go from this clock of 100 days to something reasonable? 01:07:37.300 |
ended up looking like three different parts across 01:07:48.900 |
and the NLP group decided to invest upfront in 01:07:53.660 |
so we could actually train on like 16 GPUs at once. 01:08:02.260 |
just rent on an hourly basis is 56 bucks an hour now. 01:08:06.580 |
Fifty six bucks an hour if you want to just like sit on them. 01:08:11.460 |
But if you're willing to like let anyone who has 01:08:15.180 |
the money and like wants to sit on them like preempt you, it's a lot cheaper. 01:08:19.980 |
So like across four days like that's not the worst. 01:08:27.340 |
So the scaling toolbox we ended up looking at was data parallelism. 01:08:32.420 |
You can think about this as like literally just divide and conquer. 01:08:35.220 |
How do I just parallelize work across all of these GPUs instead of one? 01:08:40.860 |
Mixed precision training, and we're going to talk a little bit about what that means. 01:08:46.820 |
Then this interesting idea called zero redundancy, 01:08:49.620 |
which is about minimizing the memory footprint of training. 01:08:52.820 |
Then later on, as you want to scale up to hundreds of billions of parameters, 01:09:05.340 |
we'll talk about things that come in handy like model parallelism. 01:09:09.700 |
There are things to consider like hardware and software limitations. 01:09:14.700 |
But some of you might be here looking at me which is like, okay, 01:09:18.740 |
do I need any of this stuff if I'm not training really big models? 01:09:23.620 |
A lot of these tips and tricks come in handy even if you don't have access to 100 GPUs or 01:09:31.020 |
even eight, but you do have access to two or four. 01:09:35.460 |
A lot of the ideas here are still ideas that I'm using when I'm training stuff on 01:09:40.100 |
my laptop or when I'm trying to run inference with the latest big model 01:09:45.140 |
But please ask questions if things become too hazy or too not useful. 01:10:01.780 |
>> So Colab, yeah, Colab, you're still limited to a single GPU. 01:10:04.980 |
>> And I'm guessing zero redundancy might help. 01:10:08.940 |
>> So mixed precision would help, kind of, definitely for running inference. 01:10:15.660 |
And zero redundancy would also help running inference. 01:10:20.380 |
>> So zero redundancy has an add-on that they wrote up in a later paper called 01:10:24.940 |
ZeRO-Infinity, which is like, what if I didn't put all my weights on the GPU at 01:10:28.500 |
once, what if I put some of them in CPU RAM or even in NVMe SSD storage? 01:10:35.100 |
So you can actually turn your laptop into a more powerful workhorse than a Colab GPU. 01:10:39.460 |
Cool, so this is a toy example kind of going through data parallelism. 01:10:47.340 |
We're running low on time-ish, so this is MNIST with an MLP. 01:11:01.580 |
I'm going to define a batch size, a data loader that's going to load from the dataset, 01:11:06.020 |
and then I'm just going to run lots and lots of gradient steps. 01:11:10.420 |
The idea here is, how do we parallelize this across multiple workers? 01:11:19.140 |
Well, that batch size you see there is totally divisible, especially given that 01:11:24.420 |
what we're doing at the end when we compute the loss is just taking an average. 01:11:36.460 |
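For reference, here is a rough sketch of that single-GPU baseline. The exact MLP shape and hyperparameters are my own assumptions, not the lecture's code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# A small MLP on MNIST (architecture and learning rate are illustrative).
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()  # default reduction="mean": the loss is an average over the batch

train_set = datasets.MNIST("data", train=True, download=True,
                           transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=512, shuffle=True)

for epoch in range(10):
    for x, y in loader:              # lots and lots of gradient steps
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```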
So just like in CPU land, where you can think about SIMD instructions, single instruction, multiple data, 01:11:42.860 |
which is kind of how most graphics hardware works, 01:11:48.260 |
we're going to now think about the SPMD paradigm: single program, multiple data. 01:11:51.220 |
I'm going to write one program, and it's just going to automatically scale to more workers, 01:11:56.900 |
because we're going to split the data across multiple machines. 01:12:03.140 |
With PyTorch's distributed data parallel tooling, a lot of the hard parts are taken care of for you. 01:12:06.220 |
These are the only lines you need to change in the implementation. 01:12:11.220 |
So the first thing we're going to do is we're going to just create something 01:12:14.780 |
called a distributed sampler, which is going to automatically partition our data 01:12:17.980 |
across the number of workers we define up front. 01:12:20.580 |
Right, we're defining a world size of eight, 01:12:25.500 |
so this is going to partition our data into eight different subsets, one for each worker. 01:12:29.540 |
We're going to wrap our nn.Module with this nice little DistributedDataParallel wrapper, 01:12:36.980 |
which is going to sync the gradients for us behind the scenes. 01:12:40.340 |
And then we're going to run this with a special command called torchrun, 01:12:45.940 |
which is just going to inject a bunch of environment variables so 01:12:48.660 |
that we can get some statistics about our local rank, 01:12:53.020 |
who's the guy who should be printing stuff to the screen, 01:12:55.300 |
who's the guy who should be logging stuff, where each worker lives. 01:13:00.060 |
And that's about it; you can do all of this and 01:13:07.860 |
you get not quite a 16x speedup, because there is some overhead from communication. 01:13:16.140 |
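Concretely, the few changed lines look roughly like this. This is a sketch that builds on the baseline above and assumes a torchrun launch with a world size of eight; it is not the lecture's actual script.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Assumes `model` and `train_set` from the baseline sketch above, launched with:
#   torchrun --nproc_per_node=8 train.py
# torchrun injects RANK / LOCAL_RANK / WORLD_SIZE into the environment.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

sampler = DistributedSampler(train_set)                          # partitions data across the 8 workers
loader = DataLoader(train_set, batch_size=64, sampler=sampler)   # 64 * 8 = 512 examples per step

model = DDP(model.cuda(), device_ids=[local_rank])               # syncs gradients during backward()

# The training loop itself is unchanged; just call sampler.set_epoch(epoch)
# at the start of each epoch so each worker reshuffles consistently.
```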
It was not good enough, because we were trying to train lots of models 01:13:19.620 |
reproducibly, five seeds for like ten different model types, so like 50 models. 01:13:26.060 |
So we needed to go a little faster than this. 01:13:30.900 |
When I am training any model with an Adam optimizer, 01:13:35.540 |
how much memory does just storing that model and the optimizer weights take up? 01:13:43.980 |
So in 32-bit precision, our model's going to have parameters, 01:13:50.260 |
where each parameter is stored with 32 bits; that's a float, four bytes. 01:13:56.380 |
Now your optimizer is also going to do this weird thing where it keeps 01:13:58.380 |
its own separate copy of the parameters, 01:14:00.540 |
kind of duplicating a little bit of work there. 01:14:04.660 |
And then Adam tracks momentum and variance, the first and second moments of the gradient, 01:14:10.940 |
so that's another 64 bits right there. 01:14:14.580 |
Add the gradients themselves, another four bytes per parameter, and the lower bound on static memory, just storing this stuff on a GPU, 01:14:19.300 |
is 20 bytes times the number of parameters that you have. 01:14:21.620 |
This doesn't include activations at all for these larger transformer models. 01:14:26.980 |
If I want to keep around every buffer, like every intermediate matrix 01:14:31.380 |
as I pass it through the network, that takes up way more space. 01:14:34.420 |
But this at least gives us something that we can reason about. 01:14:38.700 |
The training implication of this is that if I want to fit a model with 01:14:42.260 |
a billion parameters, that's going to take about 18 gigs just at rest, 01:14:49.900 |
which is problematic, because most GPUs cap out at 24 gigs. 01:14:56.100 |
The really expensive ones now have like 40 or 80, but this is still bad. 01:14:59.740 |
175 billion parameters would take three terabytes of RAM, 01:15:08.180 |
With activations, it's probably looking like ten terabytes. 01:15:15.540 |
The numbers not in bold on the slide are just the cost of putting the model on the device. 01:15:25.300 |
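As a sanity check on those numbers, here is the arithmetic as a tiny helper; this is my own back-of-the-envelope function, not anything from the lecture.

```python
def adam_fp32_static_gb(n_params: float) -> float:
    """Lower bound on static training memory with Adam in full fp32:
    4 bytes (weights) + 4 (gradients) + 4 (optimizer's parameter copy)
    + 4 (momentum) + 4 (variance) = 20 bytes per parameter.
    Activations are not included."""
    return 20 * n_params / 1e9  # gigabytes

print(adam_fp32_static_gb(1e9))    # 20.0 GB (about 18.6 GiB), the "about 18 gigs" above
print(adam_fp32_static_gb(175e9))  # 3500.0 GB, i.e. the multiple terabytes quoted above
```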
Float 32 was a standard defined in this IEEE document, IEEE 754. 01:15:29.020 |
You have a one-bit sign, an eight-bit exponent, and a 23-bit mantissa, 01:15:34.620 |
like the scientific-notation part, all the stuff that happens after the exponent. 01:15:47.220 |
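If you want to see those fields for yourself, here is a small illustration (my own, not from the slides) that unpacks a 32-bit float into its IEEE 754 sign, exponent, and mantissa bits.

```python
import struct

def float32_fields(x: float) -> str:
    """Show the IEEE 754 layout of a 32-bit float: 1 sign bit,
    8 exponent bits, 23 mantissa bits."""
    [as_int] = struct.unpack(">I", struct.pack(">f", x))  # reinterpret the 4 bytes as an integer
    bits = f"{as_int:032b}"
    return f"sign={bits[0]} exponent={bits[1:9]} mantissa={bits[9:]}"

print(float32_fields(0.1))    # 0.1 is not exactly representable; note the repeating mantissa
print(float32_fields(512.0))  # a power of two has an all-zero mantissa
```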
If I'm training a model in mixed precision, what that means is that I'm just 01:15:50.700 |
going to run everything in a forward pass and 01:15:53.900 |
part of the backwards pass in 16 bit precision instead of 32 bit precision. 01:15:59.940 |
Notably, what that means is now I'm storing my parameters in 16 bits, 01:16:05.660 |
and all of those intermediate activations that take up lots and 01:16:08.820 |
lots of memory, especially as you go bigger, are halved, which is great. 01:16:12.540 |
But the weird part about mixed precision is that not everything is in half precision. 01:16:17.900 |
So your optimizer, to stably update your model, still needs the 32-bit 01:16:21.700 |
parameter copies, your 32-bit momentum, 32-bit variance. 01:16:25.660 |
But you've dropped four bytes per parameter, and those four bytes are kind of useful. 01:16:31.260 |
Yet training with mixed precision, at least a couple years ago, and 01:16:34.940 |
it's still mostly true now, is still way faster than training with full precision. 01:16:45.580 |
NVIDIA GPUs, starting with the Volta cards, started shipping with these things called tensor cores. 01:16:51.860 |
Tensor cores are basically the individual logical units on a GPU that are responsible for matrix multiplies, 01:16:58.660 |
and your GPU is really good at accelerating neural network training because it's really good at matrix multiplies. 01:17:04.020 |
These things are optimized for 4x4 to 16x16 sized shards. 01:17:13.460 |
If you run in 16-bit precision, you can actually end up using way more tensor cores than if you were using full 32-bit precision. 01:17:18.460 |
And so you're able to get a ton of speed ups just because you're able to tap 01:17:23.420 |
into the underlying hardware of your system more frequently. 01:17:26.460 |
As of the Ampere-style cards, like the A100s, 01:17:36.140 |
those start shipping with tensor cores that can run float-32 precision math as well, but 16-bit is still faster. 01:17:43.940 |
So when you can, train with 16 bit precision. 01:17:46.740 |
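In PyTorch, one common way to get this behavior is automatic mixed precision with autocast and a gradient scaler. This is a sketch of the general recipe, reusing the names from the earlier sketches; it is not the exact setup used in the lecture.

```python
import torch

# Reuses model, optimizer, loss_fn, and loader from the sketches above (assumption).
model = model.cuda()
scaler = torch.cuda.amp.GradScaler()       # guards against fp16 gradient underflow

for x, y in loader:
    x, y = x.cuda(), y.cuda()
    with torch.cuda.amp.autocast():        # forward pass runs in fp16 where it is safe to do so
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()          # scale the loss, then backprop
    scaler.step(optimizer)                 # unscales gradients, then steps with fp32 optimizer state
    scaler.update()                        # adjusts the loss scale over time
    optimizer.zero_grad()
```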
All right, this shaves a day off of our small scale training. 01:17:54.020 |
This shaves off way more, especially as you go bigger. 01:17:58.700 |
And now the final bit is how do we eliminate the redundancies? 01:18:07.020 |
>> So why do you need the 32-bit precision for the optimizer, but not for the rest of the pass? 01:18:11.140 |
>> So when you are estimating the gradients and applying updates, precision matters more. 01:18:18.300 |
Specifically, you want those 23 mantissa bits, 01:18:27.460 |
because while the full range of float 32 is really, really big, 01:18:32.380 |
without those mantissa bits it can't be super precise in the 0 to 1 range, for example. 01:18:38.660 |
So you kind of want as many bits there as possible to ensure precision. 01:18:44.140 |
Okay, so zero redundancy. With standard data parallelism, 01:18:48.940 |
you're basically storing everything on each GPU. 01:18:51.780 |
Key idea is I don't need to store everything on every GPU. 01:18:56.380 |
I just need to store some things on every GPU. 01:18:59.100 |
So the model gets to stay on each GPU cuz the model has to do all the work. 01:19:03.140 |
But the gradients and the optimizer state, I can just split across the number of devices that I have. 01:19:06.140 |
So if I have two GPUs, half my gradients go on one device 01:19:12.900 |
and half of them go on the other device, and likewise for the optimizer state. 01:19:17.460 |
And you're actually not adding any extra communication cost. 01:19:20.540 |
You just get free memory, because you're just intelligently partitioning that state, 01:19:31.620 |
so each GPU needs less memory as you add more and more machines, right? 01:19:34.780 |
So this is kind of like the biggest trick to start training models with a billion or more parameters. 01:19:40.780 |
And now when you add this, you're at three days. 01:19:48.660 |
>> Because this would require things being slightly out of sync to optimize, 01:19:53.100 |
would you have to- >> So this actually doesn't require 01:19:55.060 |
anything out of sync, because in a backwards pass with distributed data 01:19:58.700 |
parallel, the individual updates already have to be synced across processes. 01:20:05.700 |
If you average the loss across all processes, 01:20:08.860 |
that means that the gradients you apply have to be transferred as well. 01:20:12.420 |
So this just basically does that work for you. 01:20:20.100 |
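PyTorch ships a built-in version of the optimizer-state sharding part of this idea, ZeroRedundancyOptimizer. Here is a sketch that assumes the DDP setup from earlier; the fuller gradient and parameter sharding described in the ZeRO papers lives in libraries like DeepSpeed or FSDP.

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumes the distributed setup sketched earlier (init_process_group, DDP-wrapped model).
# Each rank now stores only its shard of the Adam state (the fp32 parameter copy,
# momentum, and variance) instead of a full replica.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,
    lr=3e-4,
)
```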
As you go even bigger, you hit a communication wall, matrix multiplies 01:20:22.300 |
stop fitting on a device, so you start sharding them, 01:20:24.860 |
and then you have to start scheduling things wisely, and yeah, great. 01:20:27.980 |
For fine-tuning and inference, there is a great library that you should use 01:20:32.140 |
called PEFT from Hugging Face, it's great, and that's it.
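For completeness, a minimal PEFT sketch; the base model and LoRA hyperparameters here are illustrative choices on my part, not recommendations from the lecture.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative choices: GPT-2 as the base model and generic LoRA hyperparameters.
model = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, config)
model.print_trainable_parameters()  # only the small LoRA adapter weights are trainable
```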