Stanford XCS224U: NLU | Analysis Methods for NLU, Part 2: Probing | Spring 2023
Chapters
0:00 Intro
0:13 Overview
1:54 Recipe for probing
2:52 Core method
4:42 Probing or learning a new model?
5:41 Control tasks and probe selectivity
7:21 Simple example
9:44 From probing to multi-task training
10:37 Unsupervised probes
11:03 Summary
This is part 2 in our series on analysis methods for NLP. We've come to our first method, and that is probing. The core idea is that we're going to use supervised models, the probes, to determine what information is encoded in the hidden representations of our target models. Probing is often applied in the context of so-called BERTology, and that line of work was a really foundational contribution in this space. I think it was really eye-opening about the extent to which BERT seems to encode rich information about language in its internal representations.
Probing can be a source of valuable insights, I believe, but there are really two cautionary notes here. First, a sufficiently powerful probe might lead you to see things that aren't in your target model; the probe is, after all, a supervised model that you trained in some way. Second, and maybe more importantly for the current unit, probes cannot tell us about whether the information that we identify has any causal relationship with the target model's input-output behavior. This is really concerning for me, because what we're looking for from analysis methods is insight into the causal mechanisms that guide model behaviors. If probing falls short on offering us those causal insights, it's really intrinsically limited as an analysis method. I'm going to focus on supervised probes, but I will mention unsupervised probes near the end. They don't suffer from the concern that they're overly powerful, but they still fall short when it comes to offering causal insights.
Let's start with a recipe for probing, so that we can be careful about all of this. The first step is that you state a hypothesis about an aspect of the target model's internal structure. You could hypothesize that it stores information about part of speech, or named entities, or dependency parses. The second step is that you choose a supervised dataset that is a proxy for the internal structure of interest, and you're going to be dependent on that dataset when it comes to drawing conclusions. The third step is that you identify a place in the model, the site, where you believe the structure is encoded, and you train a supervised probe on the chosen site. Then the extent to which your probe is successful is your estimate of the degree to which your hypothesis about the model's internal structure was correct.
What I have on the slide now is a very cartoonish look at a BERT-like model with three layers, and one of its internal representations is highlighted. Let's suppose that's the site that we chose to probe. What we're going to do is fit a small linear model on that internal representation using some task labels. The way that actually plays out in practice is that we run the BERT model on the current input, grab the vector representation at our chosen site, and use it to start building a little supervised learning dataset, where this is some vector and this is a task label for our input example. Then we run the BERT model again on a different sequence. We get a different vector representation at our target site, and we add it to our supervised learning dataset with a new task label. We continue this process for maybe tens of thousands of examples, whatever we've got available to us in our probe dataset. Then we fit a small linear model on these X, y pairs.
Notice that we have used the BERT model simply as an engine for creating feature representations. You also aren't limited to probing a single vector. You could decide that the entire layer here encodes part of speech, and then you would build up a dataset consisting of lists of these vectors with their associated lists of labels, and train a part-of-speech tagging model on that basis.
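To make the core method concrete, here is a minimal sketch of a linear probe fit on frozen BERT representations. The layer index, the tiny dataset, and the tag set are illustrative placeholders rather than anything from the lecture, and the wordpiece alignment is deliberately simplified.

```python
# Minimal sketch of a linear probe on frozen BERT representations.
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
bert.eval()  # frozen: we never update BERT's parameters

LAYER = 2  # the "site" we hypothesize encodes part of speech (placeholder)

# Tiny stand-in for a real probe dataset: (words, per-word POS tags).
probe_data = [
    (["the", "cat", "sleeps"], ["DET", "NOUN", "VERB"]),
    (["dogs", "chase", "the", "ball"], ["NOUN", "VERB", "DET", "NOUN"]),
]

X, y = [], []
with torch.no_grad():
    for words, tags in probe_data:
        enc = tokenizer(words, is_split_into_words=True, return_tensors="pt")
        hidden = bert(**enc).hidden_states[LAYER][0]  # (seq_len, 768)
        seen = set()
        for pos, wid in enumerate(enc.word_ids(0)):
            if wid is None or wid in seen:
                continue  # skip special tokens and non-initial wordpieces
            seen.add(wid)
            X.append(hidden[pos].numpy())  # frozen feature vector
            y.append(tags[wid])            # task label

# The probe itself: a small linear model fit on the frozen features.
probe = LogisticRegression(max_iter=1000).fit(X, y)
print("probe train accuracy:", probe.score(X, y))
```

In a real experiment the probe would of course be evaluated on held-out examples, and its accuracy there is what gets read as evidence about the chosen site.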
The first question that arises for probing is really pressing: are we learning about the target model, or are we simply learning a new model, namely the probe model? After all, the probe is trained on top of the frozen parameters of the model we're probing. We used the BERT model as an engine for creating these feature representations that were the input to our probe. This is very hard to distinguish from simply fitting a new supervised model on features that we extracted from the site we chose, based on how BERT did its calculations. And a more powerful probe might find more information in the target model simply because it has more capacity of its own. To help address these concerns, Hewitt and Liang introduced the notion of probe selectivity.
This is just going to help us calibrate, to some extent, how much information was actually in the target model. The first step here is to define a control task. A control task has the same input-output structure as your target task, but its labels are assigned randomly, using some simple strategies that give you outputs that differ, presumably, from the ones in your gold dataset. Selectivity is then just the difference between probe performance on the task and probe performance on the control task. High-selectivity probes are the most reliable probes in terms of giving you insights into the target model. For a small, simple probe, there is likely to be a very large difference between the two. Whereas if you choose a very powerful probe model with many parameters, you'll have low selectivity, because that model has enough capacity to do well even on the random control task.
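Here is a rough sketch of that selectivity computation. The `fit_probe` callable is a hypothetical helper that trains a probe on (features, labels) and returns held-out accuracy; the word-type control task follows the spirit of Hewitt and Liang's setup.

```python
# Sketch of control tasks and probe selectivity.
import random

def control_labels(words_per_example, label_set, seed=0):
    """Assign each word type one fixed random label from the task's label set."""
    rng = random.Random(seed)
    assignment = {}
    controls = []
    for words in words_per_example:
        controls.append([assignment.setdefault(w, rng.choice(label_set)) for w in words])
    return controls

def selectivity(fit_probe, features, gold_labels, ctrl_labels):
    task_acc = fit_probe(features, gold_labels)   # probe accuracy on the real task
    control_acc = fit_probe(features, ctrl_labels)  # probe accuracy on the control task
    return task_acc - control_acc  # high selectivity = more trustworthy probe
```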
Let's work through a simple example. We imagine that we have a small neural network that takes in three numbers as inputs and perfectly computes their sum. We'll presume that it does that calculation perfectly on every input. The question is, how does it manage this feat? Here is one hypothesis: the inputs x and y come together to form an intermediate variable S1, the third input z is copied into an internal state w, and then S1 and w are modular representations that are added together to form the output representation. That's a hypothesis about how this model might work.
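In code, that hypothesized mechanism amounts to something like the following sketch; the function and variable names simply mirror the diagram.

```python
# A sketch of the hypothesized mechanism for the three-number adder:
# x and y combine into an intermediate value S1, z is copied into an
# internal state w, and the output is S1 + w.
def hypothesized_adder(x, y, z):
    S1 = x + y     # x and y come together first
    w = z          # z is copied into an internal state
    return S1 + w  # the two modular pieces are added to form the output

assert hypothesized_adder(1, 2, 3) == 6
```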
Let's suppose we have this neural network, and what we decide to do is probe two of its internal representations, L1 and L2. We probe L1 and find that L1 always perfectly encodes the identity of the third input. Then we probe L2 and we find that it always perfectly computes x plus y, according to our very simple probe model. You say, "Aha, it's a bit counterintuitive that L1 is the site encoding the third input, but nonetheless, the model is obeying my hypothesis."
Here is a look at the full internal structure of this model. Again, this model performs our task perfectly, but it simply does not obey my hypothesis. One way to see that is to look at the output vector of weights: L2 is just zeroed out as part of this computation. Our probe did find that x plus y is computed at L2, but not in a way that tells us about the input-output behavior. So how can we be confident that we're learning about the causal structure of these models from the analysis method that we've chosen?
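Here is a tiny numerical sketch of that situation. The weights are illustrative, not the lecture's actual ones; the point is only that a linear probe can perfectly recover x plus y from a hidden unit whose output weight is zero, so the probed information plays no causal role in the model's behavior.

```python
import numpy as np

# Input-to-hidden weights: column 0 computes x + y + z, column 1 computes x + y.
W_in = np.array([[1.0, 1.0],
                 [1.0, 1.0],
                 [1.0, 0.0]])
w_out = np.array([1.0, 0.0])  # the x + y unit ("L2") is zeroed out at the output

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(1000, 3))
hidden = X @ W_in             # shape (1000, 2)
outputs = hidden @ w_out      # always x + y + z: the task is solved perfectly
assert np.allclose(outputs, X.sum(axis=1))

# A linear probe on the "L2" unit recovers x + y exactly ...
target = X[:, 0] + X[:, 1]
coef, *_ = np.linalg.lstsq(hidden[:, [1]], target, rcond=None)
assert np.allclose(hidden[:, 1] * coef[0], target)

# ... but ablating that unit leaves the output unchanged, so the information
# the probe found has no causal impact on input-output behavior.
ablated = hidden.copy()
ablated[:, 1] = 0.0
assert np.allclose(ablated @ w_out, outputs)
```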
This suggests moving from probing to what you might call multi-task training, where I'm training this model to do addition, and in addition, I train it so that this representation here encodes z and this one encodes x plus y. I think it's an open question whether or not that actually induces the modularity that we're interested in. But the really deep concern for me is just that we still don't know whether the structure we induced, whatever it's like, is actually shaping performance on the core task. That's the causal question all over again.
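As a rough sketch of what that multi-task setup could look like, assuming a small PyTorch regression model for the three-number sum: the auxiliary terms push two hidden units toward z and x plus y, with an arbitrary weight of 0.5 on the auxiliary loss. Nothing here guarantees that the induced structure is causally responsible for the model's outputs.

```python
import torch
import torch.nn as nn

class Adder(nn.Module):
    def __init__(self, hidden=2):
        super().__init__()
        self.hidden_layer = nn.Linear(3, hidden)
        self.out = nn.Linear(hidden, 1)

    def forward(self, x):
        h = self.hidden_layer(x)
        return self.out(h).squeeze(-1), h

model = Adder()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(2000):
    x = torch.rand(64, 3)
    y_sum = x.sum(dim=1)
    pred, h = model(x)
    task_loss = nn.functional.mse_loss(pred, y_sum)
    # Auxiliary "probe" losses: push h[:, 0] toward z and h[:, 1] toward x + y.
    aux_loss = (nn.functional.mse_loss(h[:, 0], x[:, 2])
                + nn.functional.mse_loss(h[:, 1], x[:, 0] + x[:, 1]))
    loss = task_loss + 0.5 * aux_loss  # 0.5 is an arbitrary weighting
    opt.zero_grad()
    loss.backward()
    opt.step()
```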
Finally, a quick note: I mentioned unsupervised probes. They avoid the worry about overly powerful supervised probes, but they still, I think, suffer that limitation about causal inference. And I'm not sure that probing points the way toward improved models, because it's unclear to me whether multi-task training is really a viable general way of building the structures we care about into these models.