
Stanford XCS224U: NLU | Analysis Methods for NLU, Part 3: Feature Attribution | Spring 2023


Chapters

0:0
3:34 Attributions wrt predicted vs. actual labels
5:22 Gradients by inputs fails sensitivity
6:32 Integrated gradients: Intuition
7:27 Integrated gradients: Core computation
8:29 Sensitivity again
9:22 BERT example
12:11 A small challenge test
13:25 Summary

Whisper Transcript | Transcript Only Page

00:00:00.000 | Welcome back everyone.
00:00:06.200 | This is part three in our series on analysis methods for NLP.
00:00:09.920 | We're going to be focused on feature attribution methods.
00:00:13.260 | I should say at the start that to keep things manageable,
00:00:16.400 | we're going to mainly focus on
00:00:18.080 | integrated gradients from Sundararajan et al. 2017.
00:00:21.600 | This is a shining,
00:00:23.240 | powerful, inspiring example of
00:00:25.360 | an attribution method for reasons I will discuss.
00:00:28.340 | But it's by no means the only method in this space.
00:00:31.680 | For one-stop shopping on these methods,
00:00:34.120 | I recommend the captum.ai library.
00:00:36.680 | It will give you access to lots of gradient-based methods like IG,
00:00:41.280 | as well as many others,
00:00:42.740 | including more traditional methods like
00:00:44.900 | feature ablation and feature permutation.
00:00:47.920 | Check out captum.ai.
00:00:49.660 | In addition, if you would like a deeper dive on
00:00:52.460 | the calculations and examples that I use in this screencast,
00:00:55.960 | I recommend the notebook feature attribution,
00:00:58.340 | which is part of the course code repository.
00:01:01.900 | Now, I love the integrated gradients paper,
00:01:06.140 | Sundararajan et al. 2017,
00:01:08.220 | because of its method,
00:01:09.620 | but also because it offers
00:01:11.580 | a really nice framework for thinking about attribution in general,
00:01:15.820 | and they do that in the form of three axioms.
00:01:18.460 | I'm going to talk about two of them.
00:01:20.340 | Of the two, the more important one is sensitivity.
00:01:23.660 | This is very intuitive.
00:01:25.420 | The axiom of sensitivity for attribution methods says,
00:01:28.680 | if two inputs, x and x prime, differ only at dimension i,
00:01:32.980 | and lead to different predictions,
00:01:35.500 | then the feature associated with dimension i has non-zero attribution.
00:01:41.380 | Here's a quick example. Our model is M,
00:01:44.460 | and it takes inputs that are three-dimensional,
00:01:47.180 | and for input [1, 0, 1],
00:01:49.220 | this model outputs positive,
00:01:51.100 | and for [1, 1, 1],
00:01:52.980 | it outputs negative.
00:01:54.580 | That's a difference in the predictions,
00:01:56.900 | and that means that the feature in position 2 here must have non-zero attribution.
00:02:02.980 | Seems very intuitive because obviously this feature
00:02:06.020 | is important to the behavior of this model.
00:02:08.740 | Just quickly, I'll mention a second axiom, implementation invariance.
00:02:12.780 | If two models, M and M prime,
00:02:14.660 | have identical input-output behavior,
00:02:17.260 | then the attributions for M and M prime are identical.
00:02:20.660 | That's very intuitive.
00:02:21.820 | If the models can't be distinguished behaviorally,
00:02:24.500 | then we should give them identical attributions.
00:02:27.060 | We should not be sensitive to incidental details
00:02:30.420 | of how they were structured or how they were implemented.
00:02:34.220 | Let's begin with a baseline method, gradients by inputs.
00:02:39.260 | This is very intuitive and makes some sense from
00:02:42.300 | the perspective of doing feature attribution in deep networks.
00:02:46.220 | What we're going to do is calculate the gradients for
00:02:48.660 | our model with respect to the chosen feature that we want to
00:02:51.820 | target and multiply that value by the actual value of the feature.
00:02:56.460 | Gradients by inputs.
00:02:58.300 | It's called gradients by inputs,
00:03:00.060 | but obviously we could do this gradient taking with respect to
00:03:03.540 | any neuron in one of these deep learning models and
00:03:06.300 | multiply it by the actual value of that neuron for some example.
00:03:09.940 | Actually, this method generalizes really nicely
00:03:12.860 | to any state in a deep learning model.
00:03:15.840 | It's really straightforward to implement that.
00:03:18.060 | I've depicted that on the slide here.
00:03:19.740 | The first implementation uses raw PyTorch,
00:03:23.380 | and the second one is just a lightweight wrapper
00:03:26.100 | around the Captum implementation of input by gradient.
00:03:30.340 | Shows you how straightforward this can be.
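The gradients-by-inputs calculation described above can be sketched without any deep learning library. This is a minimal illustration, not the slide's code: the slide uses PyTorch autograd and Captum, while this sketch swaps in a hand-derived gradient for a tiny logistic model. The weights, bias, and example are hypothetical values chosen only for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def model(x, w, b):
    """Tiny logistic model: probability of the positive class."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)

def gradient_wrt_input(x, w, b):
    """Analytic gradient of the model output with respect to each feature."""
    p = model(x, w, b)
    return [p * (1.0 - p) * wi for wi in w]

def grad_x_input(x, w, b):
    """Gradients by inputs: gradient times the actual feature value."""
    grads = gradient_wrt_input(x, w, b)
    return [g * xi for g, xi in zip(grads, x)]

# Hypothetical weights and example (assumptions, not from the lecture).
w = [2.0, -1.0, 0.0]
b = 0.5
x = [1.0, 3.0, 5.0]
attributions = grad_x_input(x, w, b)
```

The third feature has a large value but zero weight, so its attribution is zero; the first two get attributions whose signs follow their weights.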
00:03:33.380 | One issue that I want to linger over here that I find
00:03:36.900 | conceptually difficult is this question
00:03:39.700 | of how we should do the attributions.
00:03:42.060 | For classifier models, we have a choice.
00:03:44.380 | We can take attributions with respect to
00:03:46.340 | the predicted labels or with respect to the actual labels,
00:03:49.900 | which are two different dimensions in
00:03:52.220 | the output vector for these models.
00:03:55.220 | Now, if the model you're studying is very high-performing,
00:03:59.260 | then the predicted and actual labels will be
00:04:01.420 | almost identical and this is unlikely to matter.
00:04:04.140 | But you might be trying to study
00:04:06.260 | a really poorly performing model to try to
00:04:08.420 | understand where its deficiencies lie,
00:04:11.100 | and that's precisely the case where these two will come apart.
00:04:14.280 | As an illustration on this slide here,
00:04:16.220 | I've defined a simple make_classification
00:04:18.580 | synthetic problem using scikit-learn.
00:04:20.900 | It has four features.
00:04:22.700 | Then I set up a shallow neural classifier
00:04:25.260 | and I deliberately under-trained it.
00:04:27.180 | It has just one training iteration.
00:04:29.540 | This is a very bad model.
00:04:31.940 | If I do attributions with respect to the true labels,
00:04:35.900 | I get one vector of attribution scores.
00:04:39.300 | If I do attributions with respect to the predicted labels,
00:04:42.980 | I get a totally different set of attribution scores.
00:04:46.980 | That confronts you with
00:04:49.340 | a difficult conceptual question of
00:04:51.260 | which ones you want to use to guide your analysis.
00:04:54.260 | They're giving us different pictures of this model.
00:04:57.580 | I think that there is no a priori reason
00:05:00.220 | to favor one over the other.
00:05:01.860 | I think it really comes down to what you're trying to
00:05:04.740 | accomplish with the analysis that you're constructing.
00:05:07.740 | The best answer I can give is to be
00:05:10.220 | explicit about your assumptions
00:05:12.700 | and about the methods that you used.
00:05:14.820 | This issue, by the way, will carry forward through
00:05:17.300 | all of these gradient-based methods,
00:05:19.100 | not just inputs by gradients.
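To make the predicted-versus-actual choice concrete, here is a small sketch under stated assumptions. It uses a hypothetical linear two-class model with four features rather than the under-trained neural classifier from the slide, and it attributes with respect to a class logit, for which the gradient is simply the weight row of that class.

```python
def logits(x, W):
    """Linear classifier: one logit per class."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in W]

def grad_x_input_for_class(x, W, target):
    """Gradients by inputs for the logit of class `target`.

    For a linear model, the gradient of that logit with respect to x
    is the weight row W[target].
    """
    return [w * xi for w, xi in zip(W[target], x)]

# Hypothetical, badly fit weights (assumptions, not from the lecture).
W = [[0.5, -0.2, 0.1, 0.0],   # class 0
     [-0.3, 0.4, 0.0, 0.2]]   # class 1
x = [1.0, 2.0, -1.0, 0.5]
true_label = 0

scores = logits(x, W)
predicted_label = max(range(len(scores)), key=scores.__getitem__)

attr_true = grad_x_input_for_class(x, W, true_label)
attr_pred = grad_x_input_for_class(x, W, predicted_label)
```

Because this model misclassifies the example, `attr_true` and `attr_pred` are computed from different output dimensions and give different pictures of the same input.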
00:05:22.420 | Here's the fundamental sticking point for gradients by inputs.
00:05:27.220 | It fails that sensitivity axiom.
00:05:29.580 | This is an example,
00:05:30.860 | a counterexample to sensitivity that comes
00:05:32.940 | directly from Sundararajan et al. 2017.
00:05:36.100 | We have a very simple model here, M.
00:05:38.660 | It takes one-dimensional inputs,
00:05:41.060 | and what it does is one minus
00:05:43.500 | the ReLU activation applied to one minus the input.
00:05:47.060 | Very simple model.
00:05:48.780 | When we use the model with input 0,
00:05:51.500 | we get 0 as the output.
00:05:53.620 | When we use the model with input 2,
00:05:56.140 | we get 1 as the output.
00:05:57.820 | We have a difference in output predictions.
00:06:00.940 | These are one-dimensional inputs,
00:06:02.780 | so we are now required by sensitivity to give
00:06:05.380 | non-zero attribution to this feature.
00:06:08.940 | But sadly, we do not.
00:06:10.740 | When you run input by gradients on input 0,
00:06:13.220 | you get 0, and when you run input by gradients on input 2,
00:06:17.260 | you also get 0,
00:06:18.580 | and that is just a direct failure
00:06:20.580 | to meet the sensitivity requirement.
00:06:23.020 | That's a worrisome thing about this baseline method.
00:06:26.540 | It cues us up nicely to think about
00:06:28.740 | how integrated gradients will do better.
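The counterexample above is small enough to check by hand. This sketch writes out the model 1 - ReLU(1 - x) and its derivative directly; the choice of derivative 0 at the kink x = 1 is an assumption that follows the usual ReLU convention.

```python
def relu(z):
    return max(0.0, z)

def M(x):
    """The one-dimensional model from the example: 1 - ReLU(1 - x)."""
    return 1.0 - relu(1.0 - x)

def grad_M(x):
    """Derivative of M with respect to x: 1 where 1 - x > 0, else 0."""
    return 1.0 if (1.0 - x) > 0 else 0.0

def grad_x_input(x):
    """Gradients by inputs for the single input feature."""
    return grad_M(x) * x

out_0, out_2 = M(0.0), M(2.0)                           # the outputs differ
attr_0, attr_2 = grad_x_input(0.0), grad_x_input(2.0)   # both attributions are zero
```

At input 0 the gradient is 1 but the input is 0; at input 2 the input is nonzero but the gradient is 0. Either way the product vanishes, even though the two outputs differ.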
00:06:31.740 | The intuition behind integrated gradients is that we are going
00:06:36.380 | to explore counterfactual versions of our input,
00:06:39.780 | and I think that is an important insight.
00:06:42.060 | As we try to get causal insights into model behavior,
00:06:45.700 | it becomes ever more essential to think about counterfactuals.
00:06:49.780 | Here's the way IG does this.
00:06:52.180 | We have two features in our space,
00:06:54.340 | X_1 and X_2, and this blue dot represents
00:06:57.020 | the example that we would like to do attributions for.
00:07:00.700 | With integrated gradients, the first thing we do is set up
00:07:03.380 | a baseline and a standard baseline
00:07:05.780 | for this would be the zero vector.
00:07:08.140 | Then we are going to create synthetic examples
00:07:11.180 | interpolated between the baseline and our actual example.
00:07:16.060 | We are going to study the gradients with respect to
00:07:18.460 | every single one of these interpolated examples,
00:07:21.820 | aggregate them together, and use all of
00:07:24.100 | that information to do our attributions.
00:07:27.500 | Here's a look at the IG calculation in some detail as you
00:07:31.500 | might implement it for an actual piece of software.
00:07:35.500 | Let's break this down into some pieces.
00:07:37.380 | Step 1, we have this vector alpha,
00:07:40.180 | and this is going to determine the number of steps that we use to
00:07:43.420 | get different synthetic inputs between
00:07:45.460 | baseline and the actual example.
00:07:48.500 | We're going to interpolate these inputs between
00:07:51.260 | the baseline and the actual example.
00:07:53.300 | That's what happens in purple here,
00:07:55.380 | according to these alpha steps.
00:07:57.820 | Then we compute the gradients for each interpolated input,
00:08:01.220 | and that part of the calculation,
00:08:02.780 | the individual ones, looks exactly
00:08:04.820 | like inputs by gradients, of course.
00:08:07.620 | Next step, we do an integral approximation through the averaging,
00:08:11.740 | that's the summation that you see.
00:08:13.380 | We're going to sum over all of
00:08:14.980 | these examples that we've taken and computed the gradients for.
00:08:19.060 | Then finally, we do some scaling to remain in
00:08:21.620 | the same region of the space as the original example.
00:08:25.060 | That is the complete IG calculation.
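The steps just described can be sketched in plain Python. This is an illustrative Riemann-sum approximation, not the slide's implementation: the function name `integrated_gradients`, the `grad_fn` argument, and the toy model `f` are all assumptions introduced here, and the gradient is supplied analytically rather than by autograd.

```python
def integrated_gradients(grad_fn, x, baseline=None, n_steps=50):
    """Riemann-sum approximation of integrated gradients.

    grad_fn(point) returns the gradient of the model output with
    respect to each input dimension at `point`.
    """
    if baseline is None:
        baseline = [0.0] * len(x)          # standard all-zeros baseline
    # Step 1: the alphas set how many interpolation steps we take.
    alphas = [k / n_steps for k in range(1, n_steps + 1)]
    total = [0.0] * len(x)
    for alpha in alphas:
        # Step 2: interpolate between the baseline and the actual example.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        # Step 3: gradient at the interpolated input.
        grads = grad_fn(point)
        total = [t + g for t, g in zip(total, grads)]
    # Step 4: average the gradients (the integral approximation).
    avg = [t / n_steps for t in total]
    # Step 5: scale by the distance between the example and the baseline.
    return [(xi - b) * a for xi, b, a in zip(x, baseline, avg)]

# Hypothetical model f(x) = x0**2 + 3*x1 with its analytic gradient.
def f(x):
    return x[0] ** 2 + 3.0 * x[1]

def grad_f(x):
    return [2.0 * x[0], 3.0]

x = [2.0, 1.0]
attributions = integrated_gradients(grad_f, x, n_steps=200)
```

A useful sanity check is that the attributions should sum to roughly `f(x) - f(baseline)`, with the gap shrinking as `n_steps` grows.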
00:08:28.780 | Let's return to sensitivity.
00:08:31.220 | We have our model M with these one-dimensional inputs,
00:08:35.620 | one minus ReLU applied to one minus x.
00:08:38.900 | This is the example from Sundararajan et al.
00:08:41.620 | I showed you before that inputs by gradients
00:08:44.420 | fails sensitivity for this example and this model.
00:08:48.260 | Integrated gradients does better.
00:08:51.300 | The reason it does better,
00:08:52.620 | you can see this here,
00:08:53.580 | we are summing over all of
00:08:55.980 | those gradient calculations and averaging through them.
00:08:59.580 | The result of all of that summing and averaging is
00:09:03.340 | an attribution of approximately one, depending on
00:09:06.140 | exactly which steps you decide to
00:09:07.980 | use for the IG calculation.
00:09:10.980 | This example is no longer a counterexample to sensitivity.
00:09:16.060 | In fact, it's provable that IG satisfies the sensitivity axiom.
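For this one-dimensional model the exact integral can be worked out by hand, which shows where the value of approximately one comes from. The derivation below assumes the zero baseline x' = 0 and the input x = 2, and uses the fact that the derivative of M is 1 for inputs below 1 and 0 for inputs above 1.

```latex
\begin{aligned}
\mathrm{IG}(x) &= (x - x') \int_0^1
  \frac{\partial M\bigl(x' + \alpha (x - x')\bigr)}{\partial x}\, d\alpha \\
&= (2 - 0) \int_0^1 \mathbf{1}[\,2\alpha < 1\,]\, d\alpha \\
&= 2 \cdot \tfrac{1}{2} = 1 = M(2) - M(0)
\end{aligned}
```

The interpolated inputs with alpha below one half still lie in the region where the gradient is 1, so the path picks up the change that the gradient at the endpoint alone misses.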
00:09:21.460 | Let's think in practical terms now.
00:09:24.260 | We're likely to be thinking about BERT style models.
00:09:27.380 | This is a cartoon version of BERT where I
00:09:29.580 | have some output labels up here at the top.
00:09:32.180 | I have a lot of hidden states and I have a lot of things
00:09:35.140 | happening all the way down to
00:09:36.820 | maybe multiple fixed embedding layers.
00:09:39.660 | The fundamental thing about IG that makes it so freeing is that
00:09:43.420 | we can do attributions with respect to
00:09:45.620 | any neuron in any state in this entire model.
00:09:49.780 | We have some of the flexibility of probing,
00:09:52.380 | but now we will get causal guarantees that
00:09:54.820 | our attributions relate to
00:09:56.380 | the causal efficacy of neurons on input-output behavior.
00:10:01.060 | Here's a complete worked example
00:10:03.600 | that you might want to work with yourselves,
00:10:05.800 | modify, study, and so forth.
00:10:07.420 | Let me walk through it at a high level for now.
00:10:10.420 | The first thing I do is load in
00:10:12.980 | a Twitter-based sentiment classifier based on
00:10:16.540 | RoBERTa that I got from Hugging Face.
00:10:19.420 | For the sake of Captum,
00:10:20.860 | you need to define your own probabilistic prediction function,
00:10:24.460 | and you need to define a function that will create for you
00:10:27.100 | representations for your base as well as for your actual example.
00:10:31.100 | Those are the things that we'll interpolate between with IG.
00:10:35.140 | You need one more function to define
00:10:37.660 | the forward part of your model.
00:10:39.820 | Here I just needed to grab the logits,
00:10:42.140 | and then IG forward and whatever layer I pick are
00:10:45.700 | the core arguments to layer integrated gradients for Captum.
00:10:49.900 | Here's my example.
00:10:51.480 | "This is illuminating."
00:10:52.420 | It has true class positive.
00:10:54.620 | I'll take attributions with respect to the positive class, the true class.
00:10:58.860 | Here are my base and actual inputs,
00:11:02.680 | and here's the actual attribution step.
00:11:05.460 | Inputs, base IDs, the target is the true class,
00:11:10.020 | and this is some other keyword argument.
00:11:12.180 | The result of that is some scores which I use,
00:11:15.200 | and then I z-score normalize them
00:11:17.340 | across individual representations in the BERT model.
00:11:20.700 | That's an additional assumption that I'm bringing in that will do
00:11:23.620 | averaging of attributions across these hidden representations.
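One way to read the normalization step just described is sketched below. This is an assumption about the procedure, not the notebook's code: it averages each token's attributions over the hidden dimension and then z-scores the resulting per-token values. The function name `token_scores` and the numbers are hypothetical.

```python
import statistics

def token_scores(attributions):
    """Collapse per-dimension attributions to one z-scored value per token.

    `attributions` is a list with one entry per token, each a list with
    one attribution per hidden dimension.
    """
    # Average over the hidden dimension to get one raw score per token.
    raw = [statistics.mean(dims) for dims in attributions]
    mean = statistics.mean(raw)
    std = statistics.pstdev(raw)
    if std == 0:
        return [0.0 for _ in raw]       # all tokens equally important
    return [(r - mean) / std for r in raw]

# Hypothetical attributions: three tokens, four hidden dimensions each.
attrs = [
    [0.1, 0.0, -0.1, 0.0],
    [0.5, 0.3, 0.1, 0.1],
    [-0.4, -0.3, -0.2, -0.1],
]
scores = token_scores(attrs)
```

After this step the scores have mean zero and unit standard deviation across tokens, which makes the color scale in a visualization comparable from one example to the next.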
00:11:28.780 | A little more calculating, and then Captum provides
00:11:31.780 | a nice function for visualizing these things.
00:11:34.380 | What you get out are little snippets of HTML that
00:11:37.500 | summarize the attributions and associated metadata.
00:11:40.660 | I've got the true label,
00:11:41.940 | the predicted label, those align.
00:11:44.820 | There's some scores, and I'm not sure
00:11:47.020 | actually what attribution label is supposed to do.
00:11:49.500 | But the nice thing is that we have color coding here with
00:11:52.300 | color proportional to the attribution score.
00:11:55.980 | You have to be a little bit careful here.
00:11:58.100 | Green means evidence toward the positive label,
00:12:02.060 | which in sentiment might be negative.
00:12:03.920 | This is meant to be the true label,
00:12:05.640 | and negative is evidence away from the true label,
00:12:08.520 | and the white there is neutral.
00:12:11.100 | Here's a fuller example, and for this one to avoid confusing myself,
00:12:14.720 | I did relabel the legend away from the true label,
00:12:17.900 | neutral with respect to the true label,
00:12:19.740 | and toward the true label.
00:12:21.380 | This is very intuitive.
00:12:23.020 | Where the true label is positive,
00:12:24.720 | we get strong attributions for "great".
00:12:27.340 | Where the true label is negative,
00:12:29.300 | we get strong attributions for words like
00:12:31.540 | "wrong" and less activation for things like "great".
00:12:34.740 | This is intuitive here that we're getting
00:12:36.860 | more activation for "said" than for "great",
00:12:40.180 | suggesting the model has learned that
00:12:42.340 | the reporting verb there modulates
00:12:45.080 | the positive sentiment that is in its scope.
00:12:48.200 | Then down at the bottom here,
00:12:50.080 | we have one of these tricky situations.
00:12:52.600 | The true label is zero,
00:12:54.240 | the predicted label is one.
00:12:56.140 | These are attributions with respect to the true label.
00:12:59.440 | We're seeing, for example,
00:13:00.840 | that "incorrect" is biased toward negative.
00:13:04.440 | My guess about these attributions is that the model is actually doing
00:13:07.880 | pretty well with this example and
00:13:09.980 | maybe missed for some incidental reason.
00:13:12.880 | But overall, I would say qualitatively,
00:13:15.940 | this slide is a reassuring picture that the model is doing
00:13:19.500 | something systematic with its features
00:13:22.300 | in making these sentiment predictions.
00:13:25.380 | To summarize, feature attribution can
00:13:29.600 | give okay characterizations of the representations.
00:13:32.900 | You really just get a scalar value about the degree of
00:13:35.920 | importance and you have to make further guesses about
00:13:39.100 | why representations are important,
00:13:43.020 | but it's still useful guidance.
00:13:45.580 | We can get causal guarantees when we use methods like IG.
00:13:50.220 | But I'm afraid that there's no clear direct path to
00:13:53.620 | using methods like IG to directly improve models.
00:13:57.260 | It's like you've just got
00:13:58.820 | some ambient information that might guide you in
00:14:02.420 | a subsequent and separate modeling step
00:14:04.880 | that would improve your model.
00:14:06.820 | That's a summary of feature attribution,
00:14:08.740 | a powerful, pretty flexible heuristic method that can
00:14:12.500 | offer useful insights about how models are solving tasks.
00:14:17.700 | [BLANK_AUDIO]