
Stanford XCS224U: NLU | Contextual Word Representations, Part 9: Distillation | Spring 2023


Transcript

Welcome back everyone. This is part nine in our series on contextual representations. For part nine, we're going to switch gears a little bit and talk about distillation. The name of the game here is going to be efficiency. We are seeking models that are smaller and therefore more efficient to use, especially at inference time, but nonetheless very performant, and distillation is a set of techniques for achieving that.

On the first day of the course, I had this slide that tracked model size over time for our large language models, and you saw it going up and up and up all the way to PaLM at 540 billion parameters. Then I offered a hopeful perspective that models would start getting smaller.

One perspective on why models might get smaller is that we can distill the essence of these really large models down into the small ones and therefore get models that are more efficient when deployed. The name of the game for distillation is that we have a teacher model that is presumably very good but also very large and therefore very expensive to use.

The goal is to train a student model that has similar input-output behavior to the teacher, but is nonetheless much more efficient to use. We can do that in very lightweight ways that simply depend on having the student mimic the teacher in terms of its basic input-output behavior. But we can also think about going deeper and having it be the case that we train the student to have internal representations that are similar in some sense to those of the teacher to gain an even deeper distillation of that teacher.

In that context, let's review some distillation objectives. What I've done here is list them out from least to most heavy duty. Of course, you'll commonly see that people take weighted averages of different elements of this list. Item 0 on the list is just a note that you will probably train your student in part on gold data for the task, if you have it available and can make use of it.

We're talking essentially about supplementing that core training with additional components of the objective. The first distillation objective and the most lightweight one is that we simply train the student to produce the same output as the teacher. This is very lightweight because at distillation time, we actually don't require any direct access to the teacher.

We simply run the teacher on all our available training data. It produces labels and then we train the student on those labels. It's a bit mysterious why that is useful. I think the guiding insight here is that there might be aspects of your training data that are noisy or just very difficult to learn.

The teacher acts as a regularizer and the student benefits from seeing the teacher's output even if it contains some mistakes because that ultimately helps with generalization. Going one layer deeper, we could train the student to have similar output behavior as the teacher at the level of the full vector of output scores.

This is in fact the centerpiece of one of the most famous distillation papers, Hinton et al. 2015. It's a little bit more heavy duty than just the output labels because we do require those entire score vectors, but it's still a purely behavioral distillation objective. Going one layer deeper, in the famous DistilBERT paper, Sanh et al. 2019, in addition to having components that are like 1 and 2, their distillation objective also has a cosine loss component.

Here what we're trying to do is have the teacher and student output states in the transformer sense be very similar to each other. This requires much more access to the teacher at distillation time because we need to do forward inference on the teacher for each example that we train the student on to get those output states and then apply the cosine loss and update the student.
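To make objectives 1 through 3 concrete, here is a minimal sketch in plain Python of how the three loss terms could be computed for a single example and combined as a weighted sum. Everything here is an illustrative assumption rather than the implementation from any of the papers: the function names, the weights, and the use of a temperature (which comes from Hinton et al. 2015 but is not discussed in this talk) are all hypothetical choices. The cosine term also assumes that the student and teacher output states have the same dimensionality.

```python
import math

def softmax(scores, temperature=1.0):
    """Turn a score vector into a probability distribution."""
    scaled = [s / temperature for s in scores]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def hard_label_loss(student_scores, teacher_scores):
    """Objective 1: cross-entropy against the teacher's predicted label."""
    label = max(range(len(teacher_scores)), key=teacher_scores.__getitem__)
    return -math.log(softmax(student_scores)[label])

def soft_label_loss(student_scores, teacher_scores, temperature=2.0):
    """Objective 2: cross-entropy against the teacher's full distribution."""
    p_teacher = softmax(teacher_scores, temperature)
    p_student = softmax(student_scores, temperature)
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))

def cosine_loss(student_state, teacher_state):
    """Objective 3: one minus the cosine similarity of two output states."""
    dot = sum(a * b for a, b in zip(student_state, teacher_state))
    norm_s = math.sqrt(sum(a * a for a in student_state))
    norm_t = math.sqrt(sum(b * b for b in teacher_state))
    return 1.0 - dot / (norm_s * norm_t)

def distillation_loss(student_scores, teacher_scores,
                      student_state, teacher_state,
                      weights=(1.0, 1.0, 1.0)):
    """Weighted sum of the three terms; the weights are hypothetical."""
    w_hard, w_soft, w_cos = weights
    return (w_hard * hard_label_loss(student_scores, teacher_scores)
            + w_soft * soft_label_loss(student_scores, teacher_scores)
            + w_cos * cosine_loss(student_state, teacher_state))

# Made-up numbers for one example: the teacher is frozen, so its scores
# and output state are just the results of forward inference.
teacher_scores, teacher_state = [3.0, 0.5, -1.0], [0.2, 0.9, -0.4]
student_scores, student_state = [2.0, 1.0, -0.5], [0.1, 0.8, -0.1]
print(distillation_loss(student_scores, teacher_scores,
                        student_state, teacher_state))
```

Note how the access requirements grow down the list: the first term only needs the teacher's label, the second needs its full score vector, and the third needs its output state for every training example.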

You could also think about tying other teacher and student states, other hidden states and maybe most prominently the embedding layers for the teacher and student model. Again, with an intuition that the models will be more alike and the student, therefore, more powerful if its internal representations mimic those of the teacher.

Then maybe even more heavy duty, and this is work that I was involved with, we train the student to mimic the counterfactual behavior of the teacher under interventions, that is, instances in which we actually change the internal state of the teacher, do the same corresponding thing to the student, and ensure that the two have matching input-output behavior.

That's a more thorough exploration of the input-output behavior, putting the model into counterfactual states with the hope that it will lead the models to have very similar causal internal structure. For 3, 4, and 5, this is very heavy duty in the sense that we do require full access to the teacher at distillation time.
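The intervention idea can be illustrated with a toy sketch. This is not the actual method or code from that line of work: the two-stage "models", the alignment between teacher unit 2 and student unit 0, and all function names are hypothetical, chosen only to show the shape of the computation. We take hidden values from a run on a source input, patch them into a run on a base input, do the aligned thing to the student, and then score how well the student's counterfactual output matches the teacher's.

```python
import math

def softmax(scores):
    """Turn a score vector into a probability distribution."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def run(model, x, patch=None):
    """Run a toy two-stage model, optionally overwriting hidden units.

    `patch` maps a hidden-unit index to the value it should be set to.
    """
    hidden = model["encode"](x)
    if patch:
        hidden = [patch.get(i, h) for i, h in enumerate(hidden)]
    return hidden, model["decode"](hidden)

def interchange(model, base, source, units):
    """Scores on `base` with the listed hidden units taken from `source`."""
    source_hidden, _ = run(model, source)
    patch = {i: source_hidden[i] for i in units}
    _, scores = run(model, base, patch)
    return scores

def counterfactual_loss(teacher, student, base, source,
                        teacher_units, student_units):
    """Cross-entropy between teacher and student outputs under intervention."""
    p_teacher = softmax(interchange(teacher, base, source, teacher_units))
    p_student = softmax(interchange(student, base, source, student_units))
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))

# Toy teacher: three hidden units, the last of which stores the sum.
teacher = {
    "encode": lambda x: [x[0], x[1], x[0] + x[1]],
    "decode": lambda h: [h[2], -h[2]],
}
# Toy student: one hidden unit that plays the same causal role.
student = {
    "encode": lambda x: [x[0] + x[1]],
    "decode": lambda h: [h[0], -h[0]],
}
# A student whose hidden unit has the opposite effect on the output.
flipped_student = {
    "encode": lambda x: [x[0] + x[1]],
    "decode": lambda h: [-h[0], h[0]],
}

base, source = (1.0, 2.0), (3.0, 4.0)
print(counterfactual_loss(teacher, student, base, source, [2], [0]))
print(counterfactual_loss(teacher, flipped_student, base, source, [2], [0]))
```

In a real setting the loss would be used to update the student's parameters, while the teacher stays frozen and is only run forward, once on the source input and once on the patched base input.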

But in all of these cases, I'm presuming that the teacher is a frozen artifact, and all you have to do is forward inference. There's another dimension to these distillation objectives that is worth thinking about. Again, these can be combined with each other and with the different modes that I just described.

In standard distillation, the teacher is frozen, as I said before, and only the student parameters are updated. We could also think about multi-teacher distillation. In this case, we have multiple teachers, maybe with different capabilities, and we simultaneously try to distill them all down into a single student that can presumably perform multiple tasks coming from those teachers.

Co-distillation is really interestingly different to think about. In this case, the student and the teacher are trained jointly. This is sometimes also called online distillation. This is very heavy duty in the sense that you're training both of these artifacts simultaneously and it's hard to think about. Self-distillation is even harder to think about.

In this case, the distillation objective includes terms that seek to make some model components align with others from the same model. In terms of performance, this is the name of the game. As I said before, we are seeking artifacts that are more efficient, but nonetheless still performant. I thought I would wrap up this short screencast by just summarizing what we know for the specific case of natural language understanding focused on GLUE.

Based on the evidence, I think it's fair to say that we can distill BERT models down into much smaller models that are still highly performant. A lot of this research has used the GLUE benchmark to track this and it's all converging on the same insight. In the famous DistilBERT paper, they took BERT-base and distilled it down into six layers with 97 percent of the GLUE performance retained.

Sun et al. did something similar. They distilled BERT-base down into three-layer and six-layer students, and also saw that they could maintain outstanding performance on GLUE. Similarly, Jiao et al. 2020 distilled BERT-base into four layers, and again saw very strong results on GLUE. This set of results here is noteworthy because it's converging on the same lesson.

We can make BERT much smaller by distilling down into a much smaller student that still does well on benchmarks like GLUE. That should be inspiring in terms of thinking about distillation as a powerful tool in your toolkit for taking very large and maybe expensive teachers and turning them into things that might have more practical utility out in the world.