Stanford XCS224U: NLU | Contextual Word Representations, Part 5: BERT | Spring 2023


0:00 Intro
0:18 BERT: Core model structure
2:07 Masked Language Modeling (MLM)
3:43 BERT: MLM loss function
4:59 Binary next sentence prediction pretraining
5:44 BERT: Transfer learning and fine-tuning
6:55 Tokenization and the BERT embedding space
7:41 BERT: Core model releases
9:27 BERT: Known limitations

00:00:00.000 | Welcome back everyone.
00:00:06.000 | This is part five in our series on contextual representations.
00:00:09.680 | We're going to focus on BERT.
00:00:11.320 | BERT is the slightly older sibling to GPT,
00:00:14.400 | but arguably just as important and famous.
00:00:18.280 | Let's start with the core model structure for BERT.
00:00:21.720 | This is mostly going to be
00:00:23.200 | combinations of familiar elements at this point,
00:00:26.020 | given that we've already reviewed
00:00:27.680 | the transformer architecture, and BERT is essentially
00:00:30.440 | just an interesting use of the transformer.
00:00:33.720 | As usual to illustrate,
00:00:35.420 | I have our sequence, the rock rules at the bottom here,
00:00:38.680 | but that sequence is augmented in
00:00:40.700 | a bunch of BERT specific ways.
00:00:42.560 | All the way on the left here, we have a class token.
00:00:45.000 | That's an important token for the BERT architecture.
00:00:47.500 | Every sequence begins with the class token.
00:00:50.780 | That has a positional encoding.
00:00:52.620 | We also have a hierarchical positional encoding.
00:00:55.680 | This is given by the token sent A.
00:00:58.080 | This won't be so interesting for our illustration,
00:01:01.200 | but as I mentioned before,
00:01:03.080 | for problems like natural language inference,
00:01:05.340 | we might have a separate token for
00:01:07.080 | the premise and a separate one for
00:01:09.140 | the hypothesis to help encode the fact that
00:01:12.080 | a word appearing in the premise is
00:01:14.200 | a slightly different occurrence of
00:01:15.920 | that word than when it appears in a hypothesis.
00:01:19.040 | That generalizes to lots of
00:01:21.340 | different hierarchical positions for
00:01:24.520 | different tasks that we might pose.
00:01:26.960 | But we have this very position
00:01:29.560 | sensitive encoding of our input sequence.
00:01:32.400 | We look up the embedding representations
00:01:34.920 | for all those pieces as usual,
00:01:37.060 | and then we do an additive combination of them to get
00:01:40.000 | our first context sensitive encoding
00:01:43.520 | of this input sequence in
00:01:44.880 | these vectors that are in green here.
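The additive combination described here can be sketched in a few lines. This is a toy illustration rather than BERT's actual implementation: the embedding size, the random vectors standing in for learned parameters, the `[SEP]` end marker, and the names `make_table` and `embed` are all assumptions made for the example.

```python
import random

rng = random.Random(0)
DIM = 4  # toy embedding size, chosen only for illustration


def make_table(keys, dim=DIM):
    """Map each key to a small random vector (a stand-in for learned embeddings)."""
    return {key: [rng.uniform(-1.0, 1.0) for _ in range(dim)] for key in keys}


tokens = ["[CLS]", "the", "rock", "rules", "[SEP]"]
segments = ["sent_A"] * len(tokens)  # a single-sentence input uses one segment label

token_emb = make_table(tokens)                   # one vector per word piece
position_emb = make_table(range(len(tokens)))    # one vector per absolute position
segment_emb = make_table(["sent_A", "sent_B"])   # the hierarchical position labels


def embed(tokens, segments):
    """Sum the token, position, and segment vectors elementwise for each input slot."""
    combined = []
    for pos, (tok, seg) in enumerate(zip(tokens, segments)):
        vec = [t + p + s for t, p, s in
               zip(token_emb[tok], position_emb[pos], segment_emb[seg])]
        combined.append(vec)
    return combined


inputs = embed(tokens, segments)
print(len(inputs), len(inputs[0]))  # one DIM-sized vector per input token
```

For a premise/hypothesis pair, the only change would be to label the second sentence's tokens `sent_B`, so the same word gets a different summed vector depending on which sentence it appears in.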
00:01:47.200 | Then just as with GPT,
00:01:49.320 | we have lots of transformer blocks,
00:01:51.280 | potentially dozens of them repeated
00:01:54.040 | until we finally get to some output states,
00:01:56.560 | which I've given in dark green here.
00:01:58.620 | Those are going to be the basis for
00:02:00.400 | further things that we do with the model.
00:02:03.000 | That's the structure.
00:02:04.680 | Let's think about how we train this artifact.
00:02:07.360 | The core objective is masked language modeling or MLM.
00:02:12.200 | The idea here is essentially that we're going to mask out or
00:02:15.760 | obscure the identities of some words in the sequence,
00:02:19.820 | and then have the model try to
00:02:21.440 | reconstruct the missing piece.
00:02:23.760 | For our sequence, we could have a scenario where we have
00:02:26.520 | no masking on the word rules,
00:02:28.880 | but we nonetheless train the model
00:02:30.840 | to predict rules at that time step.
00:02:33.060 | That might be relatively easy as a reconstruction task.
00:02:36.540 | Harder will be doing masking.
00:02:38.880 | In this case, we have a special designated token
00:02:41.840 | that we insert in the place of the token rules.
00:02:45.240 | Then we try to get the model to a state where it can
00:02:48.920 | reconstruct that rules was the missing piece
00:02:51.640 | using the full bidirectional context around that point.
00:02:55.680 | Then relatedly, in addition to masking,
00:02:57.920 | we could do random word replacement.
00:03:00.400 | In this case, we simply take the actual word,
00:03:03.040 | in this case rules,
00:03:04.180 | and replace it with a random one like every,
00:03:06.880 | and then try to have the model learn to predict
00:03:09.400 | what was the actual token at that position.
00:03:13.080 | All of these things are using
00:03:14.640 | the bidirectional context of
00:03:16.720 | the model in order to do this reconstruction task.
00:03:20.040 | When we train this model,
00:03:21.400 | we mask out only a small percentage of all the tokens,
00:03:24.800 | mostly leaving the other ones in place so that
00:03:27.720 | the model has lots of context to use to
00:03:30.040 | predict the masked or missing or corrupted tokens.
00:03:33.760 | That's actually a limitation of the model and
00:03:35.920 | an inefficiency in the MLM objective
00:03:38.540 | that ELECTRA in particular will seek to address.
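The three cases just described (leave the token in place, replace it with a mask token, replace it with a random word) can be sketched as a single corruption function. The 15 percent selection rate is mentioned later in this lecture; the 80/10/10 split among the three cases is the one reported in the BERT paper. The function name `corrupt` and the toy vocabulary are assumptions for illustration only.

```python
import random

MASK = "[MASK]"
SPECIAL = {"[CLS]", "[SEP]"}


def corrupt(tokens, vocab, rng, select_prob=0.15):
    """Return (corrupted, targets); targets[i] is None where no prediction is scored."""
    corrupted, targets = [], []
    for tok in tokens:
        # Special tokens and unselected tokens are passed through untouched.
        if tok in SPECIAL or rng.random() >= select_prob:
            corrupted.append(tok)
            targets.append(None)
            continue
        targets.append(tok)  # the model must reconstruct the original token here
        r = rng.random()
        if r < 0.8:
            corrupted.append(MASK)               # masking
        elif r < 0.9:
            corrupted.append(rng.choice(vocab))  # random word replacement
        else:
            corrupted.append(tok)                # token kept, but still predicted
    return corrupted, targets


toy_vocab = ["the", "rock", "rules", "every", "stone"]
sentence = ["[CLS]", "the", "rock", "rules", "[SEP]"]
print(corrupt(sentence, toy_vocab, random.Random(0)))
```

Only the positions with a non-`None` target contribute to training, which is the source of the inefficiency mentioned above.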
00:03:42.440 | Here's the MLM loss function in some detail.
00:03:45.920 | Again, as before with these loss functions,
00:03:48.400 | there are a lot of details here,
00:03:49.800 | but I think the crucial thing to
00:03:51.640 | zoom in on is first the numerator.
00:03:53.920 | It's very familiar from before.
00:03:55.680 | We're going to use the embedding representation
00:03:57.680 | of the token that we want to predict,
00:03:59.760 | and we're going to get a dot product of
00:04:01.600 | that with a model representation.
00:04:03.520 | In this case, we can use the entire surrounding context,
00:04:07.840 | leaving out only the representation at T.
00:04:10.960 | Whereas for the autoregressive objective
00:04:13.500 | that we reviewed before,
00:04:14.940 | we could only use
00:04:16.240 | the preceding context to make this prediction.
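The slide itself is not reproduced in this transcript. One way to write an objective matching the description is given below; the notation is an assumption: $e(x)$ is the static embedding of token $x$, $\hat{x}$ is the corrupted input sequence, $h_\theta(\hat{x})_t$ is the model's output representation at position $t$, $V$ is the vocabulary, and $m_t$ is 1 for masked or corrupted positions and 0 otherwise.

```latex
\max_{\theta} \; \sum_{t=1}^{T} m_t \,
  \log \frac{\exp\!\left( e(x_t)^{\top} h_\theta(\hat{x})_t \right)}
            {\sum_{x' \in V} \exp\!\left( e(x')^{\top} h_\theta(\hat{x})_t \right)}
```

The numerator is the dot product of the target token's embedding with the model representation at that position. Because the true token at position $t$ is hidden in $\hat{x}$, the representation $h_\theta(\hat{x})_t$ can draw on context from both sides.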
00:04:19.280 | The other thing to notice here is that we have
00:04:21.760 | this indicator function m_t here,
00:04:24.280 | which is going to be one if we're looking at
00:04:26.680 | a masked token and zero otherwise.
00:04:30.000 | What that's essentially doing is turning off
00:04:32.560 | this objective for tokens that we didn't mask out.
00:04:35.840 | We get a learning signal only from
00:04:38.360 | the masked tokens or the ones that we have corrupted.
00:04:41.720 | That again feeds into an inefficiency of
00:04:46.040 | this objective because we in effect do
00:04:48.280 | the work of making predictions for all the time steps,
00:04:51.040 | but get an error signal for the loss function only for
00:04:54.480 | the ones that we have designated as masked in some sense.
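To make the role of the indicator concrete, here is a minimal sketch of the loss computed over toy vectors. It assumes a softmax over dot products between static token embeddings and per-position hidden states, as described above; the function name `mlm_loss` and the tiny hand-picked vectors are hypothetical.

```python
import math


def mlm_loss(embeddings, hidden_states, target_ids, mask_flags):
    """Sum of -m_t * log softmax(e(x) . h_t)[x_t] over all positions t."""
    total = 0.0
    for h_t, x_t, m_t in zip(hidden_states, target_ids, mask_flags):
        if not m_t:
            # m_t = 0: the model still produced h_t, but it earns no error signal.
            continue
        # Score every vocabulary item by its dot product with the hidden state.
        scores = [sum(e_i * h_i for e_i, h_i in zip(e, h_t)) for e in embeddings]
        log_norm = math.log(sum(math.exp(s) for s in scores))
        total += -(scores[x_t] - log_norm)
    return total


# Toy setup: a 3-word vocabulary with 2-dimensional embeddings.
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hidden_states = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # one vector per position
target_ids = [0, 1, 2]
mask_flags = [0, 1, 0]  # only the middle position was masked

print(mlm_loss(embeddings, hidden_states, target_ids, mask_flags))
```

With all flags set to 0 the loss is exactly zero, no matter what the hidden states contain, which is the sense in which the objective is turned off for unmasked tokens.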
00:04:59.040 | For the BERT paper,
00:05:01.240 | they supplemented the MLM objective with
00:05:03.660 | a binary next sentence prediction task.
00:05:06.800 | In this case, we use our corpus resources to create
00:05:09.920 | actual sentence sequences with
00:05:11.920 | all of their special tokens in them.
00:05:13.960 | For sequences that actually occurred in the corpus,
00:05:17.040 | we label them as next.
00:05:18.680 | Then for negative instances,
00:05:20.320 | we have randomly chosen sentences that we
00:05:22.600 | pair up and label them as not next.
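A minimal sketch of how such training pairs might be built from a corpus follows. The 50/50 split between positive and negative pairs follows the BERT paper; the function name `make_nsp_examples`, the toy documents, and the choice to resample until the random sentence differs from the true continuation are assumptions made for the example.

```python
import random


def make_nsp_examples(documents, rng):
    """Build (sentence_a, sentence_b, label) triples from documents (lists of sentences)."""
    all_sentences = [s for doc in documents for s in doc]
    examples = []
    for doc in documents:
        for first, second in zip(doc, doc[1:]):
            if rng.random() < 0.5:
                # Positive instance: the pair actually occurred in the corpus.
                examples.append((first, second, "next"))
            else:
                # Negative instance: pair the first sentence with a random one.
                # Assumption: resample if we happen to draw the true continuation.
                candidate = rng.choice(all_sentences)
                while candidate == second:
                    candidate = rng.choice(all_sentences)
                examples.append((first, candidate, "not next"))
    return examples


docs = [
    ["The rock rules.", "Everyone agrees.", "It is a fine rock."],
    ["BERT is a transformer.", "It is trained with MLM."],
]
for a, b, label in make_nsp_examples(docs, random.Random(0)):
    print(f"[CLS] {a} [SEP] {b} [SEP] -> {label}")
```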
00:05:25.400 | The motivation for this part of
00:05:27.280 | the objective is to help the model learn
00:05:29.640 | some discourse level information
00:05:32.280 | as part of learning how to reconstruct sequences.
00:05:34.880 | I think that's a really interesting intuition about how we
00:05:38.080 | might bring in even richer notions of
00:05:39.920 | context into the transformer representations.
00:05:44.280 | When we think about transfer learning or fine-tuning,
00:05:48.040 | there are a few different approaches that we can take.
00:05:50.440 | Here's a depiction of the transformer architecture.
00:05:53.600 | The standard lightweight thing to do is to build out
00:05:57.640 | task parameters on top of
00:05:59.600 | the final output representation above the class token.
00:06:03.400 | I think that works really well because
00:06:05.320 | the class token is used as
00:06:06.800 | the first token in every single sequence that BERT processes,
00:06:11.400 | and it's always in that fixed position.
00:06:13.840 | It becomes a constant element that contains
00:06:16.520 | a lot of information about the corresponding sequence.
00:06:20.440 | The standard thing is to build
00:06:22.160 | a few dense layers on top of that,
00:06:24.400 | and then maybe do some classification learning there.
00:06:27.680 | But of course, as with GPT,
00:06:29.600 | we shouldn't feel limited by that.
00:06:31.320 | A standard alternative to this would be to pool
00:06:34.200 | together all of the output states and then build
00:06:37.040 | the task parameters on top of that mean pooling or
00:06:40.360 | max pooling or whatever decision you use
00:06:43.000 | to bring together all of
00:06:45.840 | the output states to make predictions for your task.
00:06:48.680 | That can be very powerful as well because you bring in
00:06:51.600 | much more information about the entire sequence.
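The two options, using the output state above the class token or pooling over all output states, can be compared in a small sketch. The output states here are hand-written toy vectors rather than real model outputs, and the helper names are hypothetical; task parameters (for example a few dense layers) would then be built on top of whichever vector is chosen.

```python
def cls_pool(states):
    """Use only the output state above the class token (position 0)."""
    return list(states[0])


def mean_pool(states):
    """Average each dimension across all output states."""
    n = len(states)
    return [sum(dim_values) / n for dim_values in zip(*states)]


def max_pool(states):
    """Take the maximum of each dimension across all output states."""
    return [max(dim_values) for dim_values in zip(*states)]


# Toy output states: 3 positions, 2 dimensions each.
output_states = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]

print(cls_pool(output_states))   # [1.0, 4.0]
print(mean_pool(output_states))  # [2.0, 2.0]
print(max_pool(output_states))   # [3.0, 4.0]
```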
00:06:55.240 | I thought I would remind you a little bit
00:06:57.760 | about how tokenization works.
00:06:59.480 | Remember that BERT has this tiny vocabulary and
00:07:03.360 | therefore a tiny static embedding space.
00:07:06.640 | The reason it gets away with that is because it does
00:07:09.280 | WordPiece tokenization, which means that we have lots of
00:07:12.560 | these word pieces indicated by these double hash marks here.
00:07:16.440 | That means that the model essentially never
00:07:18.960 | UNKs out any of its input tokens,
00:07:21.000 | but rather breaks them down into familiar pieces.
00:07:25.000 | Then the intuition is that the power of
00:07:28.480 | masked language modeling in particular will allow us to
00:07:31.680 | learn internal representations of things that
00:07:34.600 | correspond even to words like
00:07:36.560 | encode which got spread out over multiple tokens.
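The way a word gets broken into familiar pieces can be sketched with greedy longest-match-first segmentation, which is how WordPiece splits words at tokenization time. The toy vocabulary below is hypothetical and far smaller than BERT's real one, and the function name `wordpiece` is an assumption.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Split a word into the longest vocabulary pieces, scanning left to right."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # mark a piece that continues a word
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # No piece matched. A real vocabulary includes single characters,
            # which is why this fallback is rarely reached in practice.
            return [unk]
        pieces.append(piece)
        start = end
    return pieces


toy_vocab = {"the", "rock", "rules", "en", "##code", "##s"}

print(wordpiece("encode", toy_vocab))  # ['en', '##code']
print(wordpiece("rocks", toy_vocab))   # ['rock', '##s']
```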
00:07:40.960 | Let's talk a little bit about core model releases.
00:07:45.120 | For the original BERT paper,
00:07:47.400 | I believe they just did BERT base and BERT
00:07:49.600 | large, in cased and uncased variants.
00:07:52.200 | I would recommend always using the cased ones at this point.
00:07:55.280 | Very happily, lots of teams including
00:07:57.920 | the Google team have worked to develop even smaller ones.
00:08:00.680 | We have tiny, mini, small, and medium as well.
00:08:04.200 | This is really welcome because it means you can do a lot of
00:08:07.160 | development on these tiny models
00:08:09.240 | and then possibly scale up to larger ones.
00:08:11.720 | For example, BERT tiny has just two layers,
00:08:15.080 | that is two transformer blocks,
00:08:17.080 | relatively small model dimensionality and
00:08:19.600 | relatively small expansion inside its feed-forward layer
00:08:23.160 | for a total number of parameters of only four million.
00:08:26.320 | I will say that that is tiny,
00:08:28.240 | but it's surprising how much juice you can get
00:08:30.600 | out of it when you fine-tune it for tasks.
00:08:33.120 | But then you can move on up to mini, small, medium,
00:08:36.080 | and then large is the largest from
00:08:38.320 | the original release at 24 layers,
00:08:40.720 | relatively large model dimensionality,
00:08:43.240 | relatively large feed-forward layer for
00:08:45.720 | a total number of parameters of around 340 million.
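The parameter totals can be roughly checked with back-of-the-envelope arithmetic. The lecture gives only the layer counts and totals; the hidden sizes (128 and 1024), feed-forward sizes (512 and 4096), the vocabulary size of 30,522, and the inclusion of a pooler layer are assumptions based on the released uncased configurations as I understand them, so treat the results as estimates.

```python
def count_params(layers, hidden, feed_forward, vocab=30522, positions=512, segments=2):
    """Estimate the parameter count of a BERT-style encoder (weights and biases)."""
    # Token, position, and segment embedding tables, plus one layer norm.
    embeddings = (vocab + positions + segments) * hidden + 2 * hidden
    # Query, key, value, and output projections, each hidden x hidden plus bias.
    attention = 4 * (hidden * hidden + hidden)
    # Two dense layers: hidden -> feed_forward -> hidden, with biases.
    ffn = hidden * feed_forward + feed_forward + feed_forward * hidden + hidden
    # Two layer norms per block, each with a scale and a shift vector.
    layer_norms = 2 * (2 * hidden)
    block = attention + ffn + layer_norms
    # Dense pooler applied on top of the class token's output state.
    pooler = hidden * hidden + hidden
    return embeddings + layers * block + pooler


print(f"tiny:  {count_params(2, 128, 512):,}")
print(f"large: {count_params(24, 1024, 4096):,}")
```

Under these assumptions, most of the tiny model's parameters sit in the token embedding table, while in the large model the transformer blocks dominate.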
00:08:50.360 | All of these models, because all of them,
00:08:53.960 | as far as I know,
00:08:55.120 | use absolute positional embeddings,
00:08:57.760 | have a maximum sequence length of 512.
00:09:00.920 | That's an important limitation that increasingly we're
00:09:03.760 | feeling is constraining the kinds of
00:09:05.520 | work we can do with models like BERT.
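In practice, the 512-position limit means longer inputs have to be truncated or split before the model sees them, since there is no learned embedding for position 512 and beyond. Here is a minimal sketch of simple truncation, assuming the class token and a `[SEP]` end marker each take one position; the function name `truncate_for_bert` is hypothetical.

```python
MAX_POSITIONS = 512  # number of absolute position embeddings


def truncate_for_bert(word_pieces, max_len=MAX_POSITIONS):
    """Keep as many word pieces as fit once [CLS] and [SEP] are added."""
    body = word_pieces[: max_len - 2]
    return ["[CLS]"] + body + ["[SEP]"]


long_input = ["w"] * 1000
print(len(truncate_for_bert(long_input)))  # 512
```

Everything past the cut is simply invisible to the model, which is why this limit constrains work on long documents.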
00:09:08.560 | There are many new releases,
00:09:10.480 | and I would say to stay up to date,
00:09:12.080 | you could check out Hugging Face,
00:09:13.440 | which has variants of these models for
00:09:15.360 | different languages and maybe some different
00:09:17.120 | sizes and other kinds of things.
00:09:19.200 | Maybe, for example, there are by now versions that use
00:09:22.040 | relative positional encoding which
00:09:23.720 | would be quite welcome, I would say.
00:09:26.600 | For BERT, some known limitations,
00:09:29.120 | and this will feed into
00:09:30.320 | subsequent things that we want to talk about
00:09:32.160 | with RoBERTa and ELECTRA especially.
00:09:34.840 | First, the original BERT paper is admirably detailed,
00:09:39.320 | but it's still very partial in terms of
00:09:42.080 | ablation studies and studies of how to
00:09:44.280 | effectively optimize the model.
00:09:46.600 | That means that we might not be looking at
00:09:49.160 | the very best BERT that we could possibly
00:09:51.560 | have if we explored more widely.
00:09:54.600 | Devlin et al. also observe a downside.
00:09:58.120 | They say the first downside is that we're creating
00:10:00.400 | a mismatch between pre-training and fine-tuning
00:10:03.440 | since the mask token is never seen during fine-tuning.
00:10:06.800 | That is indeed unusual.
00:10:08.480 | Remember, the mask token is a crucial element
00:10:11.120 | in training the model against the MLM objective.
00:10:14.200 | You introduce this foreign element into that phase that
00:10:17.560 | presumably you never see when you do fine-tuning,
00:10:20.760 | and that could be dragging down model performance.
00:10:24.560 | The second downside that they
00:10:26.360 | mentioned is one that I mentioned as well.
00:10:28.400 | We're using only around 15 percent of
00:10:31.240 | the tokens to make predictions.
00:10:33.400 | We do all this work of processing these sequences,
00:10:35.800 | but then we turn off the modeling objective
00:10:38.440 | for the tokens that we didn't mask,
00:10:40.320 | and we can mask only a tiny number of them because we
00:10:43.440 | need the bidirectional context to do the reconstruction.
00:10:46.480 | That's the essence of the intuition there.
00:10:49.120 | That's obviously inefficient.
00:10:51.680 | The final one is intriguing.
00:10:53.520 | I'll mention this only at the end of this series.
00:10:55.640 | This comes from the XLNet paper,
00:10:58.120 | and they just observed that BERT assumes
00:10:59.800 | the predicted tokens are independent of each other,
00:11:02.640 | given the unmasked tokens,
00:11:04.600 | which is oversimplified as high-order,
00:11:07.160 | long-range dependency is prevalent in natural language.
00:11:10.720 | This is just the observation that if you do happen to mask out
00:11:14.040 | two tokens like New and York from the place name New York,
00:11:18.040 | the model will try to reconstruct
00:11:19.960 | those two tokens independently of each other,
00:11:22.520 | even though we can see that they have
00:11:24.280 | a very clear statistical dependency.
00:11:26.960 | The BERT objective simply misses that,
00:11:29.240 | and I'll mention later on how XLNet brings
00:11:32.920 | that dependency back in possibly to very powerful effect.
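The independence assumption can be written out for the New York example. Here $\hat{x}$ stands for the corrupted sequence with both tokens masked; the notation is an assumption chosen for illustration. BERT's objective scores the two masked positions separately:

```latex
\log p(\text{New}, \text{York} \mid \hat{x})
  \;\approx\; \log p(\text{New} \mid \hat{x}) + \log p(\text{York} \mid \hat{x})
```

A factorization that respects the dependency would instead condition the second prediction on the first:

```latex
\log p(\text{New}, \text{York} \mid \hat{x})
  \;=\; \log p(\text{New} \mid \hat{x}) + \log p(\text{York} \mid \text{New}, \hat{x})
```

The second form is exact by the chain rule; the first drops the conditioning on New, which is the statistical dependency the BERT objective misses.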