Stanford XCS224U: NLU I Analysis Methods for NLU, Part 2: Probing I Spring 2023


0:00 Intro
0:13 Overview
1:54 Recipe for probing
2:52 Core method
4:42 Probing or learning a new model?
5:41 Control tasks and probe selectivity
7:21 Simple example
9:44 From probing to multi-task training
10:37 Unsupervised probes
11:03 Summary

Whisper Transcript

00:00:00.000 | Welcome back everyone.
00:00:06.100 | This is part 2 in our series on analysis methods for NLP.
00:00:10.140 | We've come to our first method and that is probing.
00:00:13.040 | Here's an overview of how probing works.
00:00:15.380 | The core idea is that we're going to use supervised models,
00:00:18.580 | those are our probe models,
00:00:20.120 | to determine what is latently encoded in
00:00:22.560 | the hidden representations of our target models,
00:00:25.380 | the ones that we actually care about.
00:00:27.320 | Probing is often applied in the context of so-called BERTology,
00:00:31.520 | and I think Tenney et al. 2019 is
00:00:33.520 | a really foundational contribution in this space.
00:00:36.040 | As I mentioned before,
00:00:37.200 | I think this was really eye-opening about the extent to which BERT
00:00:40.480 | is inducing interesting structure about
00:00:42.960 | language from its training regimes.
00:00:46.100 | Probing can be a source of valuable insights, I believe,
00:00:49.880 | but we do need to proceed with caution,
00:00:51.680 | and there are really two cautionary notes here.
00:00:53.880 | First, a very powerful probe might lead you
00:00:57.160 | to see things that aren't in your target model,
00:00:59.560 | but rather just stored in your probe model.
00:01:02.520 | It is after all a supervised model that you trained in some way.
00:01:06.280 | Second, and maybe more importantly for the current unit,
00:01:10.000 | probes cannot tell us about whether the information that we identify
00:01:13.640 | has any causal relationship with the target model's input-output behavior.
00:01:18.240 | This is really concerning for me because what we're looking for from
00:01:21.960 | analysis methods is insights about
00:01:24.320 | the causal mechanisms that guide model behaviors.
00:01:27.680 | If probing falls short on offering us those causal insights,
00:01:31.640 | it's really intrinsically limited as an analysis method.
00:01:36.080 | I'm going to focus for this screencast
00:01:39.000 | on supervised probes to keep things simple,
00:01:41.120 | but I will mention unsupervised probes near the end.
00:01:44.800 | They don't suffer from the concern that they're overly powerful,
00:01:48.200 | but they do, I think,
00:01:49.480 | still fall short when it comes to offering causal insights.
00:01:53.680 | Let's start with a recipe for probing to be careful about this.
00:01:58.680 | The first step is that you state a hypothesis
00:02:01.280 | about an aspect of the target model's internal structure.
00:02:04.360 | You could hypothesize that it stores information about part of
00:02:07.840 | speech or named entities or dependency parses.
00:02:11.720 | You name it, the hypothesis space is open.
00:02:15.040 | You then need to choose a supervised task
00:02:18.440 | that is a proxy for the internal structure of interest.
00:02:21.360 | If you're going to look for part of speech,
00:02:23.140 | you need a part of speech dataset,
00:02:24.880 | and you're going to be dependent on that dataset when it
00:02:27.400 | comes to actually defining the probe itself.
00:02:30.440 | Then you identify a place in the model,
00:02:33.640 | a set of hidden representations where
00:02:35.720 | you believe the structure will be encoded,
00:02:38.280 | and you train a supervised probe on the chosen site.
00:02:41.960 | Then the extent to which your probe is successful is
00:02:45.700 | your estimate of the degree to which you were
00:02:47.640 | right about the underlying hypothesis.
00:02:50.680 | But there are some caveats there.
00:02:52.720 | Let's first walk through the core method.
00:02:54.800 | What I have on the slide now is a very cartoonish look at
00:02:58.920 | a BERT-like model with three layers and you can see
00:03:02.000 | these inputs have come in and we're going to
00:03:04.840 | target the hidden representation H to start.
00:03:07.920 | Let's suppose that's the site that we chose to probe.
00:03:11.400 | What we're going to do is fit a small linear model on
00:03:15.360 | that internal representation using some task labels.
00:03:19.280 | The way that actually plays out in practice is
00:03:21.920 | instructive. We're going to run
00:03:24.080 | the BERT model on the current input and we're going to
00:03:26.720 | grab the vector representation there and use it to start
00:03:30.720 | building a little supervised learning dataset where this is
00:03:34.660 | some vector and this is a task label for our input example.
00:03:39.160 | Then we run the BERT model again on a different sequence.
00:03:42.700 | We get a different vector representation at our target site,
00:03:46.520 | and that also contributes to
00:03:48.200 | our supervised learning dataset with a new task label.
00:03:51.360 | We do it again for a different input.
00:03:53.960 | We get a different vector and
00:03:55.980 | another task label and so forth and so on.
00:03:58.880 | We continue this process for maybe tens of thousands of examples,
00:04:02.600 | whatever we've got available to us in our probe dataset.
00:04:06.280 | Then we fit a small linear model on these (x, y) pairs.
00:04:10.900 | Notice that we have used the BERT model simply as an engine for
00:04:15.820 | grabbing these vector representations
00:04:18.560 | that we use for our probe model.
00:04:21.000 | Of course, I chose a single representation,
00:04:25.240 | but more commonly with BERT,
00:04:26.600 | we're doing this layer-wise.
00:04:28.160 | You could decide that the entire layer here encodes part of speech,
00:04:31.360 | and then you would build up a dataset consisting of lists of
00:04:34.040 | these vectors with their associated lists of labels and
00:04:37.040 | train a part of speech tagging model on that basis,
00:04:39.920 | and that would be your probe.
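The data-collection loop described above can be sketched in a few lines. This is a minimal illustration, not the setup from the lecture: a tiny frozen network with random weights stands in for BERT, the probe task (is the first input feature positive?) is invented for the example, and the probe is a least-squares linear classifier. All names here (`W_frozen`, `hidden_state`, `build_probe_dataset`, `fit_linear_probe`, `probe_accuracy`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical frozen "target model": one hidden layer with fixed weights.
# It stands in for BERT and is never updated while the probe is trained.
W_frozen = rng.normal(scale=0.1, size=(8, 4))

def hidden_state(x):
    """Return the hidden representation at the probed site for input x."""
    return np.tanh(W_frozen @ x)

def build_probe_dataset(inputs, labels):
    """Run the frozen model on each input and pair its vector with the task label."""
    X = np.stack([hidden_state(x) for x in inputs])
    y = np.asarray(labels, dtype=float)
    return X, y

def add_bias(X):
    """Append a constant column so the linear probe has a bias term."""
    return np.hstack([X, np.ones((len(X), 1))])

def fit_linear_probe(X, y):
    """Fit a least-squares linear probe; 0/1 labels are recoded as -1/+1 targets."""
    weights, *_ = np.linalg.lstsq(add_bias(X), 2 * y - 1, rcond=None)
    return weights

def probe_accuracy(weights, X, y):
    """Fraction of examples whose thresholded probe output matches the label."""
    preds = (add_bias(X) @ weights > 0).astype(float)
    return float((preds == y).mean())

# Toy probe task (an assumption for illustration): is the first input positive?
inputs = rng.normal(size=(2000, 4))
labels = (inputs[:, 0] > 0).astype(int)

X, y = build_probe_dataset(inputs, labels)
weights = fit_linear_probe(X[:1500], y[:1500])
test_accuracy = probe_accuracy(weights, X[1500:], y[1500:])
```

Only the probe weights are fit here; the target model contributes nothing but the feature vectors, which is what motivates the question raised next.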
00:04:41.920 | The first question that arises for probing is really pressing.
00:04:46.160 | Are we probing the target model or are we
00:04:48.560 | simply learning a new model that is the probe model?
00:04:51.320 | Probes in the current sense are
00:04:53.320 | supervised models whose inputs are
00:04:55.360 | frozen parameters of the model we're probing.
00:04:58.440 | We use the BERT model as an engine for creating
00:05:01.720 | these feature representations that are the input
00:05:03.960 | to a separate modeling process.
00:05:06.920 | This is very hard to distinguish from simply fitting
00:05:09.820 | a supervised model as usual with
00:05:11.880 | some particular choice of featurization,
00:05:14.600 | the site that we chose based on how BERT did its calculations.
00:05:19.000 | Based on 1 and 2,
00:05:21.080 | we know that at least some of
00:05:22.880 | the information that we're identifying is
00:05:25.220 | likely stored in the probe model,
00:05:27.120 | not in the target model.
00:05:29.120 | Of course, more powerful probes
00:05:31.760 | might find more information in the target model,
00:05:34.480 | but that's only because they're storing
00:05:36.480 | more information in the probe parameters.
00:05:38.380 | They have a greater capacity to do that.
00:05:41.640 | To help address this,
00:05:43.520 | Hewitt and Liang introduced the notion of probe selectivity.
00:05:47.280 | This is just going to help us calibrate to some extent
00:05:50.160 | how much information was actually in the target model.
00:05:53.760 | The first step here is to define a control task.
00:05:56.620 | This would be a random task with
00:05:58.760 | the same input-output structure as your target task.
00:06:02.160 | For example, for word sense classification,
00:06:04.440 | you could just assign words
00:06:05.960 | random fixed senses.
00:06:08.440 | For part of speech tagging,
00:06:09.980 | you could assign words to random fixed tags,
00:06:12.800 | maybe keeping the same tag distribution
00:06:15.320 | as your underlying part of speech dataset.
00:06:18.040 | Or for parsing, you could assign edges
00:06:20.680 | randomly using some simple strategies to give you
00:06:23.360 | tree structures that are very different
00:06:25.240 | presumably from the ones in your gold dataset.
00:06:28.240 | Then selectivity as a metric for probes is
00:06:31.600 | just the difference between probe performance on
00:06:34.220 | the task and probe performance on the control task.
00:06:38.040 | You've baked in how well your model
00:06:40.560 | can do on a random task. That's the idea.
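A control task for part-of-speech tagging and the selectivity metric can be sketched as follows. This is an illustrative reading of the description above, not Hewitt and Liang's code: the function names and the tiny example sentence are hypothetical, and the accuracies passed to `selectivity` are made-up numbers.

```python
import random
from collections import Counter

def make_control_labels(tokens, gold_tags, seed=0):
    """Assign every word type one random, fixed tag.

    Tags are sampled in proportion to their frequency in gold_tags, so the
    control task keeps roughly the same tag distribution as the real task.
    """
    rng = random.Random(seed)
    tag_counts = Counter(gold_tags)
    tags = sorted(tag_counts)
    weights = [tag_counts[tag] for tag in tags]
    type_to_tag = {}
    for word in sorted(set(tokens)):
        type_to_tag[word] = rng.choices(tags, weights=weights, k=1)[0]
    # Every occurrence of a word type receives the same control tag.
    return [type_to_tag[word] for word in tokens]

def selectivity(task_accuracy, control_accuracy):
    """Probe accuracy on the real task minus probe accuracy on the control task."""
    return task_accuracy - control_accuracy

tokens = ["the", "dog", "saw", "the", "cat"]
gold_tags = ["DET", "NOUN", "VERB", "DET", "NOUN"]
control_tags = make_control_labels(tokens, gold_tags)

# Made-up accuracies, purely to show the arithmetic.
example_selectivity = selectivity(0.97, 0.71)
```

Because the control tags are random by construction, any accuracy the probe reaches on them reflects what the probe itself memorized rather than structure in the target model.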
00:06:44.320 | Hewitt and Liang offer this summary picture,
00:06:47.420 | which essentially shows that
00:06:49.300 | the most reliable probes in terms of giving you insights
00:06:52.320 | will be very small ones here.
00:06:54.260 | This is a model with just two hidden units.
00:06:56.920 | That gives you very high selectivity.
00:06:58.900 | There is likely to be a very large difference
00:07:01.640 | between performance on your task and
00:07:03.780 | the performance of this control model
00:07:05.680 | when the model is very simple.
00:07:07.720 | On the other hand, if you have
00:07:09.360 | a very powerful probe model with many parameters,
00:07:12.760 | you'll have low selectivity because that model has
00:07:15.280 | such a great capacity to simply
00:07:17.440 | memorize aspects of the dataset.
00:07:20.680 | Let's move now to the second concern I have,
00:07:24.440 | which is about causal inference.
00:07:26.040 | To build this argument,
00:07:27.480 | let's use a simple example.
00:07:29.600 | We imagine that we have a small neural network that takes
00:07:33.080 | in three numbers as inputs and perfectly computes their sum.
00:07:37.080 | When 1, 3, 5 comes in,
00:07:38.880 | it does its internal magic and it outputs 9.
00:07:42.120 | We'll presume that it does that calculation perfectly
00:07:44.880 | for all triples of integers coming in.
00:07:47.960 | The question is, how does it manage this feat?
00:07:51.120 | How does this model work?
00:07:52.600 | You might have a hypothesis that it
00:07:55.000 | does it in a compositional way,
00:07:57.240 | where the first two inputs, x and y,
00:07:59.480 | come together to form an intermediate variable S1.
00:08:02.840 | The third one is copied into an internal state w,
00:08:06.460 | and then S1 and w are modular representations that are
00:08:10.140 | added together to form the output representation.
00:08:13.860 | That's a hypothesis about how this model might work.
00:08:16.900 | Now the question is, can we use probing to
00:08:19.100 | reliably assess that hypothesis?
00:08:22.340 | Let's suppose we have this neural network and what we decide is
00:08:25.860 | that L1 probably computes the input z.
00:08:30.060 | Let's suppose we fit a probe model,
00:08:32.180 | it could be a simple identity probe,
00:08:33.940 | and the probe says, yes,
00:08:35.660 | L1 always perfectly encodes the identity of the third input.
00:08:40.580 | Suppose we continue that,
00:08:42.280 | we probe L2 and we find that it always perfectly
00:08:45.540 | computes x plus y according to our very simple probe model.
00:08:50.260 | That might look like evidence for
00:08:52.640 | the hypothesis that we started with.
00:08:54.360 | You say, "Aha, it's a bit counterintuitive because L1
00:08:57.660 | encodes z and L2 x + y,
00:09:00.280 | so it's out of order,
00:09:01.740 | but nonetheless, the model is obeying my hypothesis."
00:09:06.860 | But the probes have misled you.
00:09:09.280 | Here is a look at the full internal structure of this model.
00:09:12.260 | This is all the weight parameters.
00:09:13.860 | Again, this model performs our task perfectly,
00:09:16.780 | but the point is that L2 has
00:09:18.700 | no impact at all on the output behavior.
00:09:21.940 | One way to see that is to look at the output vector of weights,
00:09:25.540 | L2 is just zeroed out as part of this computation,
00:09:28.860 | no causal impact.
00:09:30.260 | The probe said it stored x plus y,
00:09:33.180 | and it might be doing that.
00:09:34.500 | In fact, it is doing that,
00:09:35.940 | but not in a way that tells us about the input-output behavior.
00:09:39.980 | The probe in that deep way,
00:09:42.140 | in that causal way, misled us.
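The argument can be checked on a hand-built network. The weights below are hypothetical and are not the ones on the lecture slide; they only reproduce the pattern described: L1 holds z, L2 holds x + y, and the output weight on L2 is zero. A third unit, L3, is assumed here to be the one that actually carries x + y to the output.

```python
# Hypothetical weights chosen for illustration, not those from the slide.
W_HIDDEN = [
    [0, 0, 1],  # L1 = z
    [1, 1, 0],  # L2 = x + y (what a probe would find)
    [1, 1, 0],  # L3 = x + y (the unit the output actually reads)
]
W_OUT = [1, 0, 1]  # the output ignores L2: its weight is zero

def hidden(x, y, z):
    """Compute the three hidden units L1, L2, L3 for inputs x, y, z."""
    return [sum(w * v for w, v in zip(row, (x, y, z))) for row in W_HIDDEN]

def output(h):
    """Weighted sum of the hidden units."""
    return sum(w * v for w, v in zip(W_OUT, h))

def run(x, y, z, overwrite=None):
    """Run the network, optionally overwriting one hidden unit as (index, value)."""
    h = hidden(x, y, z)
    if overwrite is not None:
        index, value = overwrite
        h[index] = value
    return output(h)

probe_reading_l2 = hidden(1, 3, 5)[1]           # L2 holds x + y
normal_output = run(1, 3, 5)                    # the correct sum
l2_overwritten = run(1, 3, 5, overwrite=(1, 100))
l3_overwritten = run(1, 3, 5, overwrite=(2, 100))
```

Overwriting L2 leaves the output unchanged, while overwriting L3 changes it, so a perfect probe result on L2 says nothing about how the sum is computed.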
00:09:45.020 | The final goalpost that I set up was,
00:09:48.180 | do we have a path to improving
00:09:49.940 | models from the analysis method that we've chosen?
00:09:53.300 | Here I have a mixed answer.
00:09:55.100 | There does seem to be a path from
00:09:56.660 | probing to what you might call multi-task training,
00:09:59.340 | where I'm training this model to do addition,
00:10:01.820 | and in addition, I train it so that this representation
00:10:05.220 | here encodes z and this one encodes x plus y.
00:10:08.660 | We can certainly have such objectives.
00:10:11.100 | I think it's an open question whether or not it actually
00:10:13.820 | induces the modularity that we're interested in.
00:10:17.820 | But the really deep concern for me is just that
00:10:20.460 | still here we don't get causal guarantees.
00:10:23.260 | We can do the multi-task training,
00:10:24.780 | but that does not guarantee that
00:10:27.060 | the structure we induced, whatever it's like,
00:10:29.500 | is actually shaping performance on the core task,
00:10:32.620 | in this case of adding numbers.
00:10:34.580 | We have to proceed with caution.
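One way to picture such a multi-task objective for the addition example is a task loss plus auxiliary losses on the chosen hidden units. This is a sketch under assumptions: the squared-error form, the `aux_weight` parameter, and the convention that `hidden[0]` should encode z and `hidden[1]` should encode x + y are all choices made for illustration.

```python
def multitask_loss(output, hidden, x, y, z, aux_weight=1.0):
    """Squared error on the sum plus auxiliary squared errors on two hidden units.

    hidden[0] is pushed toward z and hidden[1] toward x + y. The aux_weight
    parameter (hypothetical) trades off the auxiliary terms against the task.
    """
    task_loss = (output - (x + y + z)) ** 2
    aux_loss = (hidden[0] - z) ** 2 + (hidden[1] - (x + y)) ** 2
    return task_loss + aux_weight * aux_loss

# A network state that satisfies both objectives for the input (1, 3, 5).
loss_when_aligned = multitask_loss(9, [5, 4, 4], 1, 3, 5)

# Correct output, but the hidden units do not encode z or x + y.
loss_when_unaligned = multitask_loss(9, [0, 0, 9], 1, 3, 5)
```

Note that the auxiliary terms only constrain what the hidden units contain. A network like the one with a zeroed-out output weight would still reach zero loss, which is the missing causal guarantee described above.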
00:10:36.980 | Finally, a quick note, I mentioned unsupervised probes.
00:10:40.380 | There's wonderful work in this space
00:10:42.260 | using a variety of different methods.
00:10:44.540 | Here are some references to really
00:10:46.420 | formative entries into that literature.
00:10:49.060 | Again, I think these techniques do not
00:10:51.060 | suffer from the concerns about probe power,
00:10:53.340 | because they don't have their own parameters
00:10:55.980 | typically, but they do,
00:10:58.540 | I think, suffer from that limitation about causal inference.
00:11:02.420 | Let's wrap up with our scorecard.
00:11:04.980 | Remember, probing can characterize
00:11:06.980 | representations really well.
00:11:08.420 | We use the supervised probe for that.
00:11:10.340 | That's a smiley face.
00:11:11.820 | But probes cannot offer causal inferences.
00:11:15.100 | I put a thinking emoji under
00:11:17.780 | improved models because it's unclear to me whether
00:11:20.540 | multi-task training is really a viable general way of
00:11:24.300 | moving from probes to better models.