
Stanford XCS224U: NLU I Contextual Word Representations, Part 7: ELECTRA I Spring 2023


Transcript

Welcome back everyone. This is part seven in our series on contextual representations. We're going to talk about the ELECTRA model. Recall that I finished the BERT screencast by listing out some known limitations of that model. RoBERTa addressed item one on that list and we can think of ELECTRA as keying into items two and three.

Item two is about the mask token. The BERT team observed that they had created a mismatch between the pre-training and fine-tuning vocabularies, because the mask token is never seen during fine-tuning, only during pre-training. You could think that that mismatch might reduce the effectiveness of the model.

Item three is about efficiency. The BERT team observed that the MLM objective means that they only use around 15% of tokens when they are training. Only 15% of them even contribute to the MLM objective. We have to do all this work of processing every item in the sequence, but we get very few learning signals from that process.

And that's certainly data inefficient and we might think about finding ways to make more use of the available data. ELECTRA is going to make progress on both these fronts. Let's explore the core model structure. For our example, we have this input sequence X, "the chef cooked the meal". The first thing we do is create X masked, which is a masked version of that input sequence.

And we could do that using the same protocol as they use for BERT by masking out, say, 15% of the tokens at random. Then we have our generator. This is a small BERT-like model that processes that input and produces what we call X corrupt. This is an output sequence predicted by the model.
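To make the masking step concrete, here is a minimal sketch in plain Python. It assumes whitespace tokens and a literal `[MASK]` string; the function name `mask_tokens` and the rule of masking at least one position are illustrative choices, not details from the ELECTRA paper.

```python
import random

MASK = "[MASK]"  # assumed mask symbol, for illustration only


def mask_tokens(tokens, mask_rate=0.15, seed=0):
    """Return (x_masked, masked_positions) for a list of tokens."""
    rng = random.Random(seed)
    # Mask at least one position so that short examples still show the effect.
    n_mask = max(1, round(len(tokens) * mask_rate))
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    x_masked = [MASK if i in positions else tok for i, tok in enumerate(tokens)]
    return x_masked, positions


x = ["the", "chef", "cooked", "the", "meal"]
x_masked, positions = mask_tokens(x)
print(x_masked, positions)
```

With five tokens and a 15% rate, this masks a single position; which one depends on the seed.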

And the twist here is that we're going to replace some of those tokens not with their original inputs, but rather with tokens sampled in proportion to the probabilities that the generator assigns to them. And what that means is that sometimes we'll replace with the actual input token and sometimes with a different token, like in this case of "cooked" coming in and being replaced by "ate".
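The sampling step can be sketched as follows. The generator's output distribution is hard-coded here as a toy dictionary (the tokens and probabilities are invented for illustration); in the real model it would come from the generator's softmax over the vocabulary.

```python
import random


def sample_replacement(probs, rng):
    """Sample one token from a dict mapping token -> probability."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


def corrupt(x_masked, generator_probs, seed=0):
    """Fill each masked position with a token sampled from the generator."""
    rng = random.Random(seed)
    x_corrupt = list(x_masked)
    for pos, probs in generator_probs.items():
        x_corrupt[pos] = sample_replacement(probs, rng)
    return x_corrupt


x_masked = ["the", "chef", "[MASK]", "the", "meal"]
# Hypothetical generator distribution for position 2 (values are made up).
generator_probs = {2: {"cooked": 0.6, "ate": 0.3, "made": 0.1}}
x_corrupt = corrupt(x_masked, generator_probs)
print(x_corrupt)
```

Because the replacement is sampled rather than taken as the most likely token, the same masked input can yield the original word on one draw and a plausible substitute on another.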

That is where ELECTRA, the discriminator, takes over. The job of the discriminator, which is really the heart of the ELECTRA model, is to figure out which of those tokens in X corrupt is an original and which was replaced. We train this model jointly, combining the generator's objective with a weighted version of the discriminator, or ELECTRA, objective.
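A rough sketch of the discriminator's targets and the joint objective, in plain Python. The helper names, the toy discriminator outputs, the generator loss of 1.0, and the weight of 50.0 are all illustrative assumptions rather than values taken from this lecture.

```python
import math


def replaced_labels(x, x_corrupt):
    """1 where the token differs from the original input, else 0."""
    return [int(a != b) for a, b in zip(x, x_corrupt)]


def binary_cross_entropy(p_replaced, labels, eps=1e-12):
    """Mean binary cross-entropy over every position in the sequence."""
    losses = [
        -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))
        for p, y in zip(p_replaced, labels)
    ]
    return sum(losses) / len(losses)


def joint_loss(mlm_loss, disc_loss, disc_weight):
    """Generator (MLM) loss plus a weighted discriminator loss."""
    return mlm_loss + disc_weight * disc_loss


x = ["the", "chef", "cooked", "the", "meal"]
x_corrupt = ["the", "chef", "ate", "the", "meal"]
labels = replaced_labels(x, x_corrupt)

# Hypothetical discriminator outputs: P(replaced) for each position.
p_replaced = [0.05, 0.10, 0.80, 0.05, 0.10]
disc_loss = binary_cross_entropy(p_replaced, labels)
total = joint_loss(mlm_loss=1.0, disc_loss=disc_loss, disc_weight=50.0)
print(labels, disc_loss, total)
```

Note that a sampled token that happens to equal the original input gets the label 0, so it counts as an original rather than a replacement.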

And then, essentially, we can allow the generator to drop away and focus on the discriminator as the primary pre-trained artifact from this process. One thing that I really love about the ELECTRA paper is that it includes very rich studies of how best to set up the ELECTRA model itself.

I'll review some of that evidence here, starting with the relationship that they uncover between the generator and the discriminator. First thing is an observation. Where the generator and discriminator are the same size, they could, in principle, share their transformer parameters. And the team found that more sharing is indeed better.

However, the best results come from having a generator that is small compared to the discriminator, which means less sharing. Here's a chart summarizing their evidence for this. Along the x-axis, I have the generator size going up to 1024. And along the y-axis, we have GLUE score, which will be our proxy for overall quality.

The blue line up here is the discriminator at size 768. And we're tracking different generator sizes, as I said. And you see this characteristic reverse U-shape, where, for example, the best discriminator at size 768 corresponds to a generator of size 256. And indeed, as the generator gets larger and even gets larger than the discriminator, performance drops off.

And that U-shape is repeated for all these different discriminator sizes, suggesting a real finding about the model. I think the intuition here is that it's kind of good to have a small and relatively weak generator so that the discriminator has a lot of interesting work to do, because after all, the discriminator is our focus.

The paper also includes a lot of efficiency studies. And those, too, are really illuminating. This is a summary of some of their evidence. Along the x-axis, we have pre-training FLOPs, which you can think of as a raw amount of overall compute needed for training. And along the y-axis, again, we have the GLUE score.

The blue line at the top here is the full ELECTRA model. And the core result here is that for any compute budget you have, that is, any point along the x-axis, ELECTRA is the best model. It looks like in second place is adversarial ELECTRA. That's an intriguing variation of the model, where the generator is actually trained to try to fool the discriminator.

That's a clear intuition that turns out to be slightly less good than the more cooperative objective that I presented before. And then the green lines are intriguing as well. So for the green lines, we begin by training just in a standard BERT fashion. And then at a certain point, we switch over to the full ELECTRA model.

And what you see there is that in switching over to full ELECTRA, you get a gain in performance for any compute budget relative to the standard BERT training continuing as before, which is the lowest line in the chart. So a clear win for ELECTRA relative to these interesting competitors.

And they did further efficiency analyses. Let me review some of what they found there. This is the full ELECTRA model as I presented it before. We could also think about ELECTRA 15%. And this is the case where for the discriminator, instead of having it make predictions about all of the input tokens, we just zoom in on the tokens that were part of this X corrupt sequence, ignoring all the rest.

That's a very BERT-like intuition where the ones that matter were these ones that got masked down here. That makes fewer predictions for the discriminator. Replace MLM is where we use the generator with no discriminator. This is a kind of ablation of BERT. And then all tokens MLM is a kind of variant of BERT where instead of turning off the objective for some of the items, we make predictions about all of them.
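The difference between full ELECTRA and ELECTRA 15% comes down to which positions contribute to the discriminator loss. Here is a minimal sketch with invented per-token loss values; `mean_loss` is a hypothetical helper, not something from the paper's code.

```python
def mean_loss(per_token_losses, positions=None):
    """Average the loss over the given positions (all positions if None)."""
    if positions is None:
        positions = range(len(per_token_losses))
    positions = list(positions)
    return sum(per_token_losses[i] for i in positions) / len(positions)


# Made-up discriminator losses for a five-token sequence.
per_token = [0.1, 0.2, 1.5, 0.1, 0.3]
masked_positions = [2]  # the only position that was masked out

full_electra = mean_loss(per_token)                      # every token counts
electra_15 = mean_loss(per_token, masked_positions)      # masked tokens only
print(full_electra, electra_15)
print(len(per_token), "vs", len(masked_positions), "supervised positions")
```

In the full setting every token yields a learning signal, whereas the 15% setting discards the signal from all unmasked positions, which is the same restriction BERT's MLM objective imposes.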

And here's a summary of the evidence that they found in favor of ELECTRA. That's at the top here, according to the GLUE score. All tokens MLM and replace MLM, those BERT variants, are just behind. And that's sort of intriguing because it shows that even if we stick to the BERT architecture, we could have done better simply by making more predictions than BERT was making initially.

Behind those is ELECTRA 15%. And that shows that on the discriminator side, again, it pays to make more predictions. If we retreat to the more BERT-like mode where we predict only for the corrupted elements, we find that performance degrades. And then at the bottom of this list is the original BERT model, showing a clear win overall for ELECTRA according to this GLUE benchmark.

The ELECTRA team released three models initially: small, base, and large. Base and large correspond roughly to the BERT releases. And small is a tiny one that they say is designed to be quickly trained on a single GPU. Again, another nod toward the increasing emphasis on compute efficiency as an important ingredient in research in this space.

Thanks.