
Stanford XCS224U: NLU | Contextual Word Representations, Part 9: Distillation | Spring 2023



00:00:00.000 | Welcome back everyone.
00:00:06.160 | This is part nine in our series on contextual representation.
00:00:09.520 | For part nine, we're going to switch gears a little bit and talk about distillation.
00:00:13.800 | The name of the game here is going to be efficiency.
00:00:16.460 | We are seeking models that are smaller and therefore more efficient to use,
00:00:21.520 | especially at inference time,
00:00:23.280 | but nonetheless very performant,
00:00:25.780 | and distillation is a set of techniques for achieving that.
00:00:29.520 | On the first day of the course,
00:00:32.100 | I had this slide that tracked model size over time for our large language models,
00:00:36.940 | and you saw it going up and up and up all the way to
00:00:39.440 | PaLM at like 540 billion parameters.
00:00:42.980 | Then I offered a hopeful perspective that models would start getting smaller.
00:00:47.320 | One perspective on why models might get smaller is that we can
00:00:51.800 | distill the essence of these really large models down into
00:00:55.740 | the small ones and therefore get models that are more efficient when deployed.
00:01:00.560 | The name of the game for distillation is that we have a teacher model that is
00:01:06.840 | presumably very good but also very large and therefore very expensive to use.
00:01:12.520 | The goal is to train a student model that has
00:01:16.160 | similar input-output behavior to the teacher,
00:01:19.740 | but is nonetheless much more efficient to use.
00:01:23.120 | We can do that in very lightweight ways that simply depend on having
00:01:27.060 | the student mimic the teacher in terms of its basic input-output behavior.
00:01:31.880 | But we can also think about going deeper and having it be the case that we train
00:01:36.540 | the student to have internal representations that are similar in some sense to
00:01:41.060 | those of the teacher to gain an even deeper distillation of that teacher.
00:01:46.580 | In that context, let's review some distillation objectives.
00:01:50.660 | What I've done here is list them out from least to most heavy duty.
00:01:55.100 | Of course, you'll commonly see that people take
00:01:57.820 | weighted averages of different elements of this list.
00:02:01.240 | For item 0 on the list,
00:02:03.800 | I just mentioned that you will probably distill your student in
00:02:07.640 | part by training it on gold data for the task,
00:02:10.180 | if you have it available and can make use of it.
00:02:12.440 | We're talking essentially about supplementing
00:02:15.780 | that core training with additional components of the objective.
00:02:20.820 | The first distillation objective and the most lightweight one is that
00:02:25.380 | we simply train the student to produce the same output as the teacher.
00:02:29.880 | This is very lightweight because at distillation time,
00:02:32.900 | we actually don't require any direct access to the teacher.
00:02:36.100 | We simply run the teacher on all our available training data.
00:02:39.620 | It produces labels and then we train the student on those labels.
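As a concrete illustration of that first, most lightweight objective, here is a minimal PyTorch-style sketch of one training step where the student fits the teacher's predicted labels. The names `teacher`, `student`, `batch`, and the key `"input_ids"` are hypothetical placeholders for whatever models and data pipeline you have.

```python
import torch
import torch.nn.functional as F

def hard_label_distill_step(teacher, student, batch, optimizer):
    """One step of hard-label distillation: the student fits the teacher's predictions."""
    with torch.no_grad():                          # the teacher is frozen
        teacher_logits = teacher(batch["input_ids"])
    pseudo_labels = teacher_logits.argmax(dim=-1)  # teacher's hard output labels
    student_logits = student(batch["input_ids"])
    loss = F.cross_entropy(student_logits, pseudo_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Since the teacher's labels can be computed once, offline, this requires no access to the teacher at all during the student's training loop.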
00:02:43.980 | It's a bit mysterious why that is useful.
00:02:47.260 | I think the guiding insight here is that there might be
00:02:50.460 | aspects of your training data that are noisy or just very difficult to learn.
00:02:55.540 | The teacher acts as a regularizer and the student benefits from seeing
00:03:00.660 | the teacher's output even if it contains
00:03:03.500 | some mistakes because that ultimately helps with generalization.
00:03:08.020 | Going one layer deeper,
00:03:10.320 | we could train the student to have similar output behavior as
00:03:13.600 | the teacher at the level of the full vector of output scores.
00:03:17.560 | This is in fact the centerpiece of one of
00:03:19.360 | the most famous distillation papers, Hinton et al. 2015.
00:03:23.460 | It's a little bit more heavy duty than just the output labels
00:03:26.920 | because we do require those entire score vectors,
00:03:29.800 | but it's still a purely behavioral distillation objective.
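A sketch of what that score-vector objective can look like, in the spirit of Hinton et al. 2015: soften both score vectors with a temperature and penalize the divergence between the resulting distributions. The temperature value here is purely illustrative.

```python
import torch.nn.functional as F

def soft_label_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between temperature-softened teacher and student distributions."""
    t = temperature
    student_log_probs = F.log_softmax(student_logits / t, dim=-1)
    teacher_probs = F.softmax(teacher_logits / t, dim=-1)
    # The t**2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (t ** 2)
```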
00:03:34.620 | Going one layer deeper, in the famous DistilBERT paper,
00:03:38.680 | Sanh et al. 2019,
00:03:40.760 | in addition to having components that are like 1 and 2,
00:03:44.120 | their distillation objective also has a cosine loss component.
00:03:48.500 | Here what we're trying to do is have the teacher and
00:03:51.040 | student output states in the transformer sense be very similar to each other.
00:03:56.000 | This requires much more access to the teacher at distillation time because we need
00:04:01.100 | to do forward inference on the teacher for each example that we train
00:04:04.880 | the student on to get those output states and
00:04:08.120 | then apply the cosine loss and update the student.
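Here is a sketch of that cosine component, assuming the teacher and student output states have the same dimensionality (as in DistilBERT, where the student keeps the teacher's hidden size):

```python
import torch
import torch.nn.functional as F

def cosine_alignment_loss(student_hidden, teacher_hidden):
    """Encourage student output states to point in the same direction as the teacher's."""
    # Both tensors have shape (batch, seq_len, dim); dimensions must match.
    s = student_hidden.reshape(-1, student_hidden.size(-1))
    t = teacher_hidden.reshape(-1, teacher_hidden.size(-1))
    target = torch.ones(s.size(0), device=s.device)  # target of 1 means "be similar"
    return F.cosine_embedding_loss(s, t, target)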
00:04:12.520 | You could also think about tying other teacher and student states,
00:04:16.740 | other hidden states and maybe most prominently
00:04:19.040 | the embedding layers for the teacher and student model.
00:04:21.640 | Again, with an intuition that the models will be more alike and the student,
00:04:25.440 | therefore, more powerful if its internal representations mimic those of the teacher.
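One simple way to tie states like the embedding layers, sketched here under the assumption that teacher and student share a vocabulary and embedding dimension, is a mean-squared-error penalty between the two weight matrices:

```python
import torch.nn.functional as F

def embedding_tie_loss(student_embedding_weight, teacher_embedding_weight):
    """Penalize the distance between student and teacher embedding matrices."""
    # Assumes the two models share a vocabulary and an embedding dimension.
    return F.mse_loss(student_embedding_weight, teacher_embedding_weight)
```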
00:04:30.760 | Then maybe even more heavy duty,
00:04:32.960 | this is work that I was involved with,
00:04:34.720 | we now train the student to mimic the counterfactual behavior of the teacher under
00:04:40.040 | interventions, that is, instances in which we actually change the internal state of
00:04:44.480 | the teacher, do the corresponding thing to
00:04:48.680 | the student, and ensure that the two have matching input-output behavior.
00:04:51.880 | That's a more thorough exploration of the input-output behavior,
00:04:55.720 | putting the model into counterfactual states with the hope that it will
00:04:59.600 | lead the models to have very similar causal internal structure.
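A highly simplified sketch of the idea behind this counterfactual objective follows. The helpers `get_hidden` and `run_with_intervention` are hypothetical (in practice such interventions are typically implemented with forward hooks on the chosen layers), and the softened KL loss on the counterfactual outputs is just one reasonable choice of matching loss, not the exact procedure from the paper.

```python
import torch
import torch.nn.functional as F

def interchange_loss(teacher, student, base, source,
                     teacher_layer, student_layer, temperature=2.0):
    """Match student and teacher outputs after the same interchange intervention."""
    with torch.no_grad():
        # Hidden state the frozen teacher computes on the source input at the chosen layer.
        t_state = teacher.get_hidden(source, layer=teacher_layer)
        # Teacher's counterfactual output: run on the base input with that state swapped in.
        t_logits = teacher.run_with_intervention(base, layer=teacher_layer, value=t_state)
    # Do the corresponding intervention on the student (gradients flow through this part).
    s_state = student.get_hidden(source, layer=student_layer)
    s_logits = student.run_with_intervention(base, layer=student_layer, value=s_state)
    t = temperature
    return F.kl_div(F.log_softmax(s_logits / t, dim=-1),
                    F.softmax(t_logits / t, dim=-1),
                    reduction="batchmean") * (t ** 2)
```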
00:05:03.600 | For 3, 4, and 5,
00:05:06.040 | this is very heavy duty in the sense that we do require
00:05:08.640 | full access to the teacher at distillation time.
00:05:11.400 | But in all of these cases,
00:05:13.140 | I'm presuming that the teacher is a frozen artifact,
00:05:16.000 | and all you have to do is forward inference.
00:05:19.160 | There's another dimension to
00:05:21.560 | these distillation objectives that is worth thinking about.
00:05:24.200 | Again, these can be combined with each other
00:05:27.120 | and with the different modes that I just described.
00:05:30.160 | In standard distillation, the teacher is frozen,
00:05:33.760 | as I said before,
00:05:34.840 | and only the student parameters are updated.
00:05:38.080 | We could also think about multi-teacher distillation.
00:05:41.280 | In this case, we have multiple teachers,
00:05:43.160 | maybe with different capabilities,
00:05:44.700 | and we simultaneously try to distill them all down into a single student that can
00:05:49.620 | presumably perform multiple tasks coming from those teachers.
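One way multi-teacher distillation can be set up, sketched here by reusing the hypothetical `soft_label_loss` from the earlier sketch and simply averaging it over a list of frozen teachers:

```python
import torch

def multi_teacher_loss(teachers, student, batch, temperature=2.0):
    """Average a soft-label distillation loss over several frozen teachers."""
    student_logits = student(batch["input_ids"])
    with torch.no_grad():
        all_teacher_logits = [teacher(batch["input_ids"]) for teacher in teachers]
    # soft_label_loss is the temperature-scaled KL sketch defined above.
    losses = [soft_label_loss(student_logits, t_logits, temperature)
              for t_logits in all_teacher_logits]
    return torch.stack(losses).mean()
```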
00:05:54.320 | Co-distillation is interestingly different to think about.
00:05:57.880 | In this case, the student and the teacher are trained jointly.
00:06:01.800 | This is sometimes also called online distillation.
00:06:04.920 | This is very heavy duty in the sense that you're training both of
00:06:07.760 | these artifacts simultaneously and it's hard to think about.
00:06:12.520 | Self-distillation is even harder to think about.
00:06:15.240 | In this case, the distillation objective includes terms that seek to make
00:06:19.640 | some model components align with others from the same model.
00:06:25.240 | In terms of performance,
00:06:28.320 | this is the name of the game.
00:06:29.840 | As I said before,
00:06:30.840 | we are seeking artifacts that are more efficient,
00:06:33.120 | but nonetheless still performant.
00:06:35.120 | I thought I would wrap up this short screencast by just summarizing what we
00:06:39.280 | know for the specific case of natural language understanding, focused on GLUE.
00:06:44.000 | Based on the evidence,
00:06:45.440 | I think it's fair to say that we can distill
00:06:47.880 | BERT models down into much smaller models that are still highly performant.
00:06:52.960 | A lot of this research has used the GLUE benchmark to track
00:06:56.400 | this and it's all converging on the same insight.
00:06:59.340 | In the famous DistilBERT paper,
00:07:01.300 | they took BERT base and distilled it down into six layers
00:07:04.960 | with 97 percent of the GLUE performance retained.
00:07:08.880 | Sun et al. did something similar.
00:07:10.940 | They distilled BERT base down into three-layer and six-layer students,
00:07:14.380 | and also saw that they could maintain outstanding performance on GLUE.
00:07:18.940 | Similarly, Jiao et al. 2020 distilled BERT base into four layers,
00:07:23.460 | and again saw very strong results on GLUE.
00:07:26.680 | This set of results here is noteworthy because it's converging on the same lesson.
00:07:31.180 | We can make BERT much more efficient by distilling it down into
00:07:34.500 | a much smaller student that still does well on benchmarks like GLUE.
00:07:39.180 | That should be inspiring in terms of thinking about
00:07:41.500 | distillation as a powerful tool in your toolkit for taking very large and maybe
00:07:46.960 | expensive teachers and turning them into things that might
00:07:50.060 | have more practical utility out in the world.