Stanford XCS224U: NLU | Contextual Word Representations, Part 9: Distillation | Spring 2023
This is part nine in our series on contextual representation. For part nine, we're going to switch gears a little bit and talk about distillation. The name of the game here is going to be efficiency. We are seeking models that are smaller and therefore more efficient to use, and distillation is a set of techniques for achieving that.
I had a slide that tracked model size over time for our large language models, and you saw it going up and up and up. Then I offered a hopeful perspective that models would start getting smaller. One perspective on why models might get smaller is that we can distill the essence of these really large models down into small ones, and therefore get models that are more efficient when deployed.
The name of the game for distillation is that we have a teacher model that is presumably very good, but also very large and therefore very expensive to use. The goal is to train a student model that has similar input-output behavior to the teacher, but is nonetheless much more efficient to use. We can do that in very lightweight ways that simply depend on having the student mimic the teacher in terms of its basic input-output behavior. But we can also think about going deeper and training the student to have internal representations that are similar, in some sense, to those of the teacher, to gain an even deeper distillation of that teacher.
In that context, let's review some distillation objectives. What I've done here is list them out from least to most heavy duty, and of course you'll commonly see that people take weighted averages of different elements of this list. It's also presupposed that you will probably train your student directly on gold task labels, if you have them available and can make use of them; we're talking essentially about supplementing that core training with additional components of the objective.
The first distillation objective, and the most lightweight one, is that we simply train the student to produce the same output as the teacher. This is very lightweight because, at distillation time, we actually don't require any direct access to the teacher. We simply run the teacher on all our available training data, it produces labels, and then we train the student on those labels.
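In code, a minimal sketch of this first objective might look like the following, assuming generic PyTorch classifier modules called `teacher` and `student` and an optimizer for the student; the names and loop structure are my own illustration, not something from the course materials.

```python
import torch
import torch.nn.functional as F

def hard_label_distillation_step(teacher, student, optimizer, inputs):
    """One training step in which the student mimics the teacher's predicted labels."""
    teacher.eval()
    with torch.no_grad():                              # no gradients through the teacher
        teacher_logits = teacher(inputs)               # run the teacher on the batch
        pseudo_labels = teacher_logits.argmax(dim=-1)  # the teacher's predicted labels

    student_logits = student(inputs)
    loss = F.cross_entropy(student_logits, pseudo_labels)  # train on teacher labels

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Note that the teacher could even be run once, offline, to produce a labeled dataset; nothing in this objective requires keeping the teacher around while the student trains.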
I think the guiding insight here is that there might be aspects of your training data that are noisy or just very difficult to learn. The teacher acts as a regularizer, and the student can even benefit from the teacher's mistakes on such examples, because learning from those smoother predictions ultimately helps with generalization.
Second, we could train the student to have similar output behavior to the teacher at the level of the full vector of output scores. This is the approach taken in one of the most famous distillation papers, Hinton et al. 2015. It's a little bit more heavy duty than just matching output labels, because we do require those entire score vectors, but it's still a purely behavioral distillation objective.
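Here is a minimal sketch of this soft-target objective in the spirit of Hinton et al. 2015; the temperature value and the scaling convention are illustrative choices on my part, not details from the lecture.

```python
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between temperature-softened teacher and student distributions."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # reduction="batchmean" matches the mathematical definition of KL divergence;
    # the temperature**2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
```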
Going one layer deeper, in the famous DistilBERT paper, in addition to having components that are like 1 and 2, their distillation objective also has a cosine loss component. Here, what we're trying to do is have the teacher and student output states, in the transformer sense, be very similar to each other. This requires much more access to the teacher at distillation time, because we need to do forward inference on the teacher for each example that we train the student on, to get those output states, and then apply the cosine loss and update the student.
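Here is a hedged sketch of such a cosine component, assuming both models expose hidden states of the same dimensionality (or that the student's have been projected to match); the tensor shapes are my own assumption.

```python
import torch
import torch.nn.functional as F

def hidden_state_cosine_loss(student_hidden, teacher_hidden):
    """Encourage student hidden states to point in the same direction as the teacher's.

    Both tensors are assumed to have shape (batch, seq_len, hidden_dim).
    """
    # Compare each token position independently.
    s = student_hidden.reshape(-1, student_hidden.size(-1))
    t = teacher_hidden.reshape(-1, teacher_hidden.size(-1))
    target = torch.ones(s.size(0), device=s.device)  # target=1 means "should be similar"
    return F.cosine_embedding_loss(s, t, target)
```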
You could also think about tying other teacher and student states, other hidden states, and maybe most prominently the embedding layers for the teacher and student model, again with the intuition that the models will be more alike, and the student therefore more powerful, if its internal representations mimic those of the teacher.
Finally, going even deeper, we can train the student to mimic the counterfactual behavior of the teacher under interventions, that is, instances in which we actually change the internal state of the teacher, do the same corresponding thing to the student, and ensure that the two have matching input-output behavior. That's a more thorough exploration of the input-output behavior, putting the models into counterfactual states, with the hope that it will lead the models to have very similar causal internal structure.
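To make the idea concrete, here is a toy sketch under my own assumptions: `run_lower` and `run_upper` are hypothetical helpers that split each model at a chosen intervention site, the dimension indices that align teacher and student states are placeholders, and the KL matching loss is just one way to enforce matching outputs.

```python
import torch
import torch.nn.functional as F

def interchange_intervention_loss(teacher, student, base_inputs, source_inputs,
                                  teacher_dims, student_dims):
    """Counterfactual distillation via an interchange intervention, in sketch form.

    Assumes each model exposes two hypothetical helpers:
      run_lower(x) -> hidden state at the intervention site
      run_upper(h) -> output logits computed from that hidden state
    teacher_dims / student_dims are index objects (e.g. slices) picking out the
    aligned parts of each model's hidden state.
    """
    with torch.no_grad():
        # Teacher counterfactual: run on the base input, but overwrite part of the
        # hidden state with what the teacher computes for the source input.
        t_base = teacher.run_lower(base_inputs)
        t_source = teacher.run_lower(source_inputs)
        t_base[..., teacher_dims] = t_source[..., teacher_dims]
        teacher_logits = teacher.run_upper(t_base)

    # The corresponding intervention on the student (gradients flow through the student).
    s_base = student.run_lower(base_inputs).clone()
    s_source = student.run_lower(source_inputs)
    s_base[..., student_dims] = s_source[..., student_dims]
    student_logits = student.run_upper(s_base)

    # Train the student's counterfactual output to match the teacher's.
    return F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean")
```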
Of course, this is very heavy duty in the sense that we do require full access to the teacher at distillation time in order to perform those interventions.

Throughout, I have been presuming that the teacher is a frozen artifact, but there is another dimension to these distillation objectives that is worth thinking about: the mode in which distillation is performed, which interacts with the different objectives that I just described. In standard distillation, the teacher is indeed frozen, and only the student's parameters are updated.
We could also think about multi-teacher distillation, where we have multiple teachers, perhaps performing different functions, and we simultaneously try to distill them all down into a single student that can presumably perform multiple tasks coming from those teachers.
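As a simple sketch of one way this could be set up, assuming for simplicity that the teachers all share a single output space (which real multi-task teachers may not), we might average the teachers' soft targets:

```python
import torch
import torch.nn.functional as F

def multi_teacher_distillation_loss(student, teachers, inputs):
    """Distill several frozen teachers into one student by averaging their soft targets."""
    with torch.no_grad():
        teacher_probs = [F.softmax(t(inputs), dim=-1) for t in teachers]
    target = torch.stack(teacher_probs).mean(dim=0)      # average teacher distribution

    log_p_student = F.log_softmax(student(inputs), dim=-1)
    return F.kl_div(log_p_student, target, reduction="batchmean")
```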
Co-distillation is really interestingly different to think about. In this case, the student and the teacher are trained jointly. This is sometimes also called online distillation. It is very heavy duty in the sense that you're training both of these artifacts simultaneously, and it's hard to think about.
Self-distillation is even harder to think about. In this case, the distillation objective includes terms that seek to make some model components align with others from the same model.
Since the name of the game here is efficiency, and we are seeking artifacts that are more efficient, I thought I would wrap up this short screencast by just summarizing what we know for the specific case of natural language understanding, focused on GLUE. There have been many efforts to distill BERT models down into much smaller models that are still highly performant. A lot of this research has used the GLUE benchmark to track this, and it's all converging on the same insight.
For example, the DistilBERT work took BERT-base and distilled it down into six layers, with 97 percent of the GLUE performance retained. Other work distilled BERT-base into three-layer and six-layer students and also saw that it could maintain outstanding performance on GLUE. Similarly, Jiao et al. 2020 distilled BERT-base into just four layers and again retained strong GLUE performance. This set of results is noteworthy because it's converging on the same lesson: we can make BERT much smaller by distilling it down into a much smaller student that still does well on benchmarks like GLUE.
That should be inspiring in terms of thinking about distillation as a powerful tool in your toolkit for taking very large and maybe expensive teachers and turning them into things that might have more practical utility out in the world.