Stanford XCS224U I Analysis NLU, Pt 4: Causal Abstraction & Interchange Intervention Training (IIT)


0:27 Recipe for causal abstraction
6:00 Interchange intervention accuracy (IIA)
8:03 Findings from causal abstraction
10:34 Connections to the literature
11:02 Summary
11:36 Method
13:45 Findings from IIT

00:00:00.000 | Welcome back everyone.
00:00:06.120 | This is part 4 in our series on analysis methods for NLP.
00:00:09.680 | We've come to our third set of methods, causal abstraction.
00:00:13.920 | I've been heavily involved with developing these methods.
00:00:16.960 | I think they're tremendously exciting because they offer
00:00:19.360 | a real opportunity for
00:00:21.160 | causal concept level explanations
00:00:24.080 | of how our NLP models are behaving.
00:00:26.960 | Let's begin with a recipe for this causal abstraction analysis.
00:00:31.880 | Step 1, you state a hypothesis about
00:00:34.720 | some aspect of your target model's causal structure,
00:00:38.320 | and you could express this as a small computer program.
00:00:42.640 | In step 2, we're going to search for an alignment between
00:00:46.520 | variables in this causal model we've defined
00:00:49.480 | and sets of neurons in the target model.
00:00:52.740 | This is a hypothesis about how the roles for
00:00:55.600 | those variables and sets of neurons align with each other.
00:00:59.880 | To do this analysis,
00:01:02.360 | to assess these alignments,
00:01:03.760 | we perform the fundamental operation of
00:01:06.000 | causal abstraction analysis, the interchange intervention.
00:01:10.200 | Much of this screencast is going to be devoted to giving you
00:01:13.180 | a feel for how interchange interventions work.
00:01:17.120 | For a running example,
00:01:19.100 | let's return to our simple neural network that
00:01:21.240 | takes in three numbers and adds them together.
00:01:24.360 | We assume that this network is successful at its task,
00:01:27.720 | and the question is,
00:01:28.880 | in human interpretable terms,
00:01:30.960 | how does the network perform this function?
00:01:34.320 | As before, we can hypothesize
00:01:36.960 | a causal model that's given in green here.
00:01:39.720 | The idea behind this causal model is that the network is
00:01:43.260 | adding together the first two inputs to
00:01:45.400 | form an intermediate variable S1,
00:01:48.180 | and then the third input is copied over
00:01:50.740 | into an intermediate variable W,
00:01:52.980 | and S1 and W are the elements that
00:01:55.880 | directly contribute to the output of the model.
00:01:58.920 | That's a hypothesis about what might be
00:02:01.200 | happening with our otherwise opaque neural model,
00:02:04.240 | and the question is,
00:02:05.420 | is the hypothesis correct?
00:02:07.600 | We're going to use interchange interventions to
00:02:10.120 | help us assess that hypothesis.
00:02:12.580 | We'll break this down into a few pieces.
00:02:14.600 | First, we hypothesize that
00:02:16.760 | the neural representation L3 plays the same role as S1.
00:02:22.360 | Let's assess that idea.
00:02:23.960 | The first intervention happens on the causal model.
00:02:27.280 | We take our causal model and we process
00:02:29.160 | example 1, 3, 5, and we get 9.
00:02:31.800 | We use that same causal model to process 4,
00:02:34.880 | 5, 6, and we get 15.
00:02:37.640 | Now the intervention comes.
00:02:39.380 | We're going to target the S1 variable
00:02:41.560 | for the right-hand example that has value 9,
00:02:44.180 | literally take that value and place it in
00:02:46.980 | the corresponding place in the left-hand example.
00:02:50.120 | The causal model is completely understood by us,
00:02:53.000 | and so we know exactly what will happen now.
00:02:55.400 | The output will change to 14.
00:02:57.680 | The input nodes below
00:02:59.520 | the variable that we intervened on don't matter in this case.
00:03:02.400 | The intervention fully wipes them out,
00:03:04.480 | and we're just adding 9 and 5 together.
00:03:07.640 | That's the causal model.
00:03:09.320 | We assume that we understand it before we begin the analysis.
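The intervention on the causal model can be sketched in a few lines of Python. This is only an illustration: the function name `causal_model` and the `interventions` dictionary are assumptions made for this sketch, not code from the lecture or its notebook. The variable names S1 and W and the two examples come from the walkthrough above.

```python
def causal_model(x, y, z, interventions=None):
    """Hypothesized high-level program: S1 = x + y, W = z, output = S1 + W."""
    interventions = interventions or {}
    # An intervened variable ignores its inputs and takes the supplied value.
    s1 = interventions.get("S1", x + y)
    w = interventions.get("W", z)
    return {"S1": s1, "W": w, "output": s1 + w}


base = causal_model(1, 3, 5)     # left-hand example
source = causal_model(4, 5, 6)   # right-hand example

# Interchange intervention: copy S1 from the source run into the base run.
swapped = causal_model(1, 3, 5, interventions={"S1": source["S1"]})

print(base["output"], source["output"], swapped["output"])  # 9 15 14
```

Because the intervened value replaces whatever S1 would have computed, the first two inputs of the base example no longer affect the output, which is why the result is simply 9 + 5.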
00:03:13.160 | The interesting part comes when we think about the neural model.
00:03:16.560 | We don't know how this neural model works,
00:03:18.760 | and we're going to try to use these interventions to uncover that.
00:03:22.200 | We process 1, 3,
00:03:24.120 | 5 with our neural model and we get 9.
00:03:26.800 | We process 4, 5, 6,
00:03:28.800 | and we get 15.
00:03:30.320 | Now we're going to intervene on the L3 state.
00:03:33.480 | We target that in the right-hand example,
00:03:35.960 | and we literally take those values and place
00:03:38.600 | them in the corresponding spot in the left-hand example.
00:03:42.440 | We study the output.
00:03:44.400 | If the output after that intervention is 14,
00:03:48.380 | then we have one piece of evidence that L3 plays the same causal role as S1.
00:03:54.920 | If we repeat this intervention for every conceivable input to these models,
00:04:00.320 | and we always see this alignment between causal model and neural model,
00:04:04.840 | we have proven that L3 plays the same causal role as S1.
00:04:10.160 | We can continue this for other variables.
00:04:12.280 | Let's target now L1.
00:04:13.680 | Suppose we hypothesize that it plays the same role as
00:04:17.160 | W in the causal model.
00:04:19.200 | Again, let's first intervene on the causal model.
00:04:21.680 | We target that W variable on the right-hand.
00:04:24.440 | We take that value and we place it in the corresponding place in the left-hand model.
00:04:29.480 | We study the output: the intervention has changed it to 10.
00:04:33.160 | Then we return to our neural models.
00:04:35.240 | Parallel operation, target L1 on the right,
00:04:38.480 | take that value and literally place it into
00:04:41.200 | the corresponding spot in the left and we study the output.
00:04:44.920 | Again, if the output is 10,
00:04:46.560 | we have a single piece of evidence that L1 and W are causally aligned in this way.
00:04:52.400 | If we repeat this intervention for every possible input and always see this correspondence,
00:04:57.500 | we have proven that L1 and W play the same causal roles.
00:05:02.480 | We could go one step further.
00:05:04.180 | Suppose we think about L2.
00:05:05.940 | Suppose we intervene on L2 in every way we can think of,
00:05:10.060 | and we never see an impact on the output behavior of the model.
00:05:14.220 | In that way, we have proven that L2 plays
00:05:17.160 | no causal role in the input-output behavior of this network.
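The same operation on the neural side can be sketched with a tiny hand-built network. Everything here is an assumption for illustration: the weights are chosen by hand so that the hypothesized alignment happens to hold (L3 sums the first two inputs, L1 copies the third, and the output ignores L2), and the helper names `forward` and `interchange` are hypothetical. A trained network would not come with weights this tidy, which is exactly why the interventions are needed.

```python
# Hand-set weights (illustrative assumption, not a trained model).
W_HIDDEN = [
    [0.0, 0.0, 1.0],   # L1 copies the third input
    [1.0, -1.0, 0.0],  # L2 computes something the output ignores
    [1.0, 1.0, 0.0],   # L3 sums the first two inputs
]
W_OUT = [1.0, 0.0, 1.0]
L1, L2, L3 = 0, 1, 2


def forward(inputs, patch=None):
    """Run the network; `patch` maps a hidden index to a value to swap in."""
    hidden = [sum(w * x for w, x in zip(row, inputs)) for row in W_HIDDEN]
    for idx, value in (patch or {}).items():
        hidden[idx] = value
    output = sum(w * h for w, h in zip(W_OUT, hidden))
    return hidden, output


def interchange(base, source, idx):
    """Copy hidden unit `idx` from the source run into the base run."""
    source_hidden, _ = forward(source)
    _, output = forward(base, patch={idx: source_hidden[idx]})
    return output


base, source = (1, 3, 5), (4, 5, 6)
print(interchange(base, source, L3))  # 14.0, matches the S1 intervention
print(interchange(base, source, L1))  # 10.0, matches the W intervention
print(interchange(base, source, L2))  # 9.0, output unchanged
```

Each matching output is one piece of evidence for the alignment; the argument in the lecture comes from repeating this over many base and source pairs.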
00:05:21.000 | Since we can assume that the input variables are aligned across causal and neural models,
00:05:25.960 | and we can assume that the output variables are aligned,
00:05:28.700 | we have now fully proven via all these intervention experiments that
00:05:33.240 | that causal model in green is an abstraction
00:05:36.000 | of the otherwise more complex neural model.
00:05:39.440 | That is exciting. If we have actually established this,
00:05:42.400 | then we are licensed to allow the neural model to fall away,
00:05:46.120 | and we can reason entirely in terms of the causal model,
00:05:49.540 | secure that the two models are causally aligned.
00:05:53.580 | They have the same underlying mechanisms.
00:05:57.080 | Now, that is an ideal of causal abstraction analysis.
00:06:02.060 | There are a few things from the real world that are going to intervene.
00:06:05.540 | The first is that we can never perform the full set of interventions.
00:06:09.480 | For all realistic cases,
00:06:10.920 | there are too many inputs.
00:06:12.440 | Even for the case of my tiny addition network,
00:06:15.120 | there is an infinitude of possible inputs,
00:06:18.260 | we can't check them all.
00:06:19.440 | We have to pick a small subset of examples.
00:06:23.320 | Then otherwise, for real models,
00:06:25.380 | we're never going to see perfect causal abstraction relationships because of
00:06:29.360 | the messy nature of naturally trained models that we use.
00:06:34.280 | We need some graded notion of success,
00:06:36.720 | and I think interchange intervention accuracy
00:06:39.200 | is a good initial baseline metric for that.
00:06:42.720 | The IIA is the percentage of interchange interventions that you
00:06:46.820 | performed that lead to outputs that match
00:06:49.680 | those of the causal model under the chosen alignment.
00:06:52.560 | You can think of it as an accuracy measure for your hypothesized alignment.
00:06:58.280 | IIA is scaled in [0, 1],
00:07:00.720 | as with a normal accuracy metric.
00:07:03.440 | It can actually be above task performance.
00:07:06.720 | This is striking,
00:07:08.160 | and it has happened to us in practice.
00:07:10.380 | If the interchange interventions put the model
00:07:12.880 | into a better state than it was in originally,
00:07:15.280 | then you might actually see a boost in
00:07:17.560 | performance from these Frankenstein examples that you have created.
00:07:21.720 | This is really fundamental here.
00:07:24.040 | IIA is extremely sensitive to
00:07:26.360 | the set of interchange interventions that you decided to perform.
00:07:29.440 | If you can't perform all of them,
00:07:30.900 | you have to pick a subset,
00:07:32.340 | and that will be a factor in shaping your accuracy results.
00:07:36.920 | In particular, pay close attention to
00:07:40.040 | how many interchange interventions should change the output label.
00:07:43.880 | Those are the ones that are really providing
00:07:45.920 | causal insights because you see exactly what should
00:07:48.840 | happen in terms of changes once you have performed the intervention.
00:07:53.800 | Having an abundance of these causally insightful interventions is
00:07:58.160 | the most powerful thing you can do in terms of building an argument.
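A minimal sketch of how IIA might be computed, along with the share of interventions that should change the output. The "network" here is a hypothetical one invented for this example: it adds its three inputs correctly, but its L3 unit stores only the first input rather than the sum of the first two, so the hypothesized alignment of L3 with S1 is wrong. The function names are also assumptions.

```python
from itertools import product


def causal_counterfactual(base, source):
    # Causal model with S1 taken from the source: (src_x + src_y) + base_z.
    return (source[0] + source[1]) + base[2]


def neural_counterfactual(base, source):
    # Hypothetical network: L1 = z, L2 = y, L3 = x, output = L1 + L2 + L3.
    # It solves the task, but L3 holds only x. With the source's L3 swapped in:
    return base[2] + base[1] + source[0]


def iia(pairs):
    """Fraction of interventions whose output matches the causal model."""
    hits = sum(
        neural_counterfactual(b, s) == causal_counterfactual(b, s)
        for b, s in pairs
    )
    return hits / len(pairs)


def label_changing_fraction(pairs):
    """Fraction of interventions that should change the base output."""
    changed = sum(causal_counterfactual(b, s) != sum(b) for b, s in pairs)
    return changed / len(pairs)


inputs = list(product(range(3), repeat=3))  # small grid of inputs
pairs = list(product(inputs, inputs))       # every (base, source) pair

print(f"IIA = {iia(pairs):.3f}")                                 # IIA = 0.333
print(f"label-changing = {label_changing_fraction(pairs):.3f}")  # label-changing = 0.765
```

In this sketch the network is perfect at the task, yet IIA is low, because the interventions succeed only when base and source happen to share the same second input. Changing which pairs are sampled would change the reported number, which is the sensitivity described above.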
00:08:02.960 | Let me briefly summarize some findings from causal abstraction.
00:08:07.200 | These are mostly from our work.
00:08:09.500 | Fine-tuned BERT models succeed at hard out-of-domain examples
00:08:13.880 | involving lexical entailment and negation
00:08:16.000 | because they are abstracted by simple monotonicity programs.
00:08:19.720 | I emphasize because,
00:08:21.240 | and I wrote it in blue there because I am not
00:08:23.440 | being casual with that causal language.
00:08:25.840 | I really intend a causal claim.
00:08:28.280 | That is the kind of thing that
00:08:30.160 | causal abstraction licenses you to be able to say.
00:08:34.080 | Relatedly, fine-tuned BERT models succeed at
00:08:37.760 | the MQNLI task because they find compositional solutions.
00:08:42.240 | MQNLI is the multiply quantified NLI benchmark.
00:08:45.880 | It's a synthetic benchmark full of
00:08:47.800 | very intricate compositional analyses
00:08:50.640 | between quantifiers and modifiers and so forth.
00:08:53.840 | A challenging benchmark,
00:08:55.360 | and we show with causal abstraction that models
00:08:57.920 | succeed to the extent that they actually
00:09:00.280 | find compositional solutions to the task.
00:09:03.760 | Models succeed at the MNIST pointer value retrieval task
00:09:07.880 | because they are abstracted by simple programs like,
00:09:11.060 | if the digit is six,
00:09:12.300 | then the label is in the lower left.
00:09:14.720 | A brief digression there,
00:09:16.700 | I love these explanations.
00:09:19.040 | That simple program that I described is
00:09:21.400 | more or less a description of the task.
00:09:23.720 | It's wonderfully reassuring to see that
00:09:26.040 | our explanations actually align with
00:09:28.640 | the task structure for these very successful models.
00:09:31.920 | Another nice point here is that we're starting to see
00:09:34.840 | a blurring of the distinction between
00:09:37.360 | neural models and symbolic models.
00:09:39.800 | After all, if you can show that the two are
00:09:42.120 | aligned via causal abstraction,
00:09:44.400 | then there really is no meaningful difference between the two,
00:09:48.040 | which leads you to wonder whether there's truly
00:09:50.480 | a meaningful difference between symbolic AI and neural AI.
00:09:54.860 | They can certainly come together and you see
00:09:57.400 | them coming together in these analyses.
00:10:00.620 | Finally, BART and T5 use
00:10:03.240 | coherent entity and situation representations
00:10:05.800 | that evolve as the discourse unfolds.
00:10:08.400 | Li et al. 2021 use causal abstraction
00:10:11.300 | in order to substantiate that claim.
00:10:14.220 | Very exciting to see.
00:10:16.440 | If you would like to get hands-on with these ideas,
00:10:19.600 | I would encourage you to check out our notebook.
00:10:21.760 | It's called IIT Equality.
00:10:23.560 | It walks through causal abstraction analysis
00:10:25.840 | using simple toy examples,
00:10:28.240 | and then also shows you how to apply IIT,
00:10:31.000 | which is the next topic we'll discuss.
00:10:34.200 | There isn't time to cover this in detail,
00:10:36.760 | but I did want to call out that causal abstraction is
00:10:39.980 | a toolkit corresponding to a large family
00:10:43.360 | of intervention-based methods for understanding our models.
00:10:47.140 | I've listed a few other exciting entries in this literature here.
00:10:51.200 | If you would like even more connections to the literature,
00:10:54.200 | I recommend this blog post that we did,
00:10:56.440 | which relates a lot of
00:10:57.840 | these methods to causal abstraction itself.
00:11:01.600 | Let's return to our summary scorecard.
00:11:04.940 | We're talking about intervention-based methods.
00:11:07.600 | I claim that they can characterize representations richly.
00:11:10.920 | After all, we show how those representations correspond
00:11:13.840 | to interpretable high-level variables.
00:11:16.920 | I've also tried to argue that this is
00:11:18.920 | a causal inference method,
00:11:21.080 | and I still have a smiley under improved models.
00:11:23.960 | I have not substantiated that for you yet,
00:11:26.720 | but that is the next task under
00:11:28.760 | the heading of interchange intervention training.
00:11:31.860 | Let's turn to that now, IIT.
00:11:35.360 | The method is quite simple and builds
00:11:38.360 | directly on causal abstraction with interchange interventions.
00:11:42.120 | Here's a summary diagram of
00:11:44.840 | interchange intervention using our addition example,
00:11:48.120 | with the one twist that you'll notice that
00:11:50.800 | my intervention now for L3 has led to an incorrect result.
00:11:55.140 | We wanted 14 and we got four.
00:11:58.720 | We have in some sense shown that
00:12:01.360 | our hypothesized alignment between
00:12:03.520 | these variables is not correct.
00:12:05.640 | But I think you can also see in here an opportunity to do better.
00:12:09.680 | We can correct this misalignment if we want to.
00:12:12.680 | After all, we know what the label
00:12:15.080 | should have been and we know what it was.
00:12:18.160 | That gives us a gradient signal that we can
00:12:21.280 | use to update the parameters of this model and make it
00:12:24.920 | conform more closely to our underlying causal model under this alignment.
00:12:29.860 | Let's see how that would play out.
00:12:31.360 | We get our error signal and that flows back as
00:12:33.840 | usual to the hidden states L1, L2, and L3.
00:12:38.080 | For L1, the gradients flow back as usual to the input states.
00:12:43.080 | The same thing is true for L2.
00:12:45.480 | But for L3, we have a more complicated update.
00:12:48.780 | We have literally copied over the full computation graph in
00:12:52.800 | the PyTorch sense including all the gradient information.
00:12:56.720 | What we get for L3 is a double update coming from
00:13:00.560 | our current example as well as the source example,
00:13:03.800 | which also processed that representation.
00:13:06.520 | We get a double update.
00:13:08.260 | The result of repeatedly performing
00:13:11.360 | these IIT updates on these models using
00:13:14.160 | the causal model for the labels as we've done here,
00:13:17.340 | is that we push the model to modularize information
00:13:21.000 | about S1 in this case in the L3 variable.
00:13:25.000 | The importance of alignments falls away and the emphasis
00:13:29.200 | here is on actually pushing models,
00:13:31.120 | improving them by making them
00:13:33.360 | have the causal structure that we have
00:13:35.440 | hypothesized in the hopes that they will then perform in
00:13:38.600 | more systematic ways and be better at the tasks we've set for them.
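A single IIT update can be sketched as follows, under strong simplifying assumptions: the network is a small linear model with arbitrary starting weights, the gradients are written out by hand instead of coming from an autograd library, and only the counterfactual loss is used. The names `iit_loss` and `iit_step` are hypothetical. The sketch does not reproduce every detail of the update described above; it only shows the core idea that the label comes from the causal model's counterfactual output and that the swapped unit is trained through the source input.

```python
def forward_hidden(W, x):
    """Hidden units of a linear network: one dot product per row of W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def iit_loss(W, v, base, source, unit, target):
    h = forward_hidden(W, base)
    h[unit] = forward_hidden(W, source)[unit]  # interchange intervention
    y = sum(vi * hi for vi, hi in zip(v, h))
    return (y - target) ** 2, h, y


def iit_step(W, v, base, source, unit, target, lr=0.001):
    """One gradient-descent step on the squared counterfactual error."""
    _, h, y = iit_loss(W, v, base, source, unit, target)
    g = 2.0 * (y - target)  # d loss / d output
    new_v = [vi - lr * g * hi for vi, hi in zip(v, h)]
    new_W = []
    for i, row in enumerate(W):
        # The swapped unit was computed from the source input, so its
        # weights receive their gradient through the source, not the base.
        x = source if i == unit else base
        new_W.append([w - lr * g * v[i] * xj for w, xj in zip(row, x)])
    return new_W, new_v


W = [[0.1, 0.2, 0.3], [0.0, 0.1, 0.0], [0.3, 0.2, 0.1]]  # arbitrary start
v = [1.0, 0.5, 1.0]
base, source, L3 = (1, 3, 5), (4, 5, 6), 2
target = (source[0] + source[1]) + base[2]  # causal model's counterfactual: 14

loss_before, _, _ = iit_loss(W, v, base, source, L3, target)
W, v = iit_step(W, v, base, source, L3, target)
loss_after, _, _ = iit_loss(W, v, base, source, L3, target)
print(loss_before, loss_after)  # the counterfactual loss goes down
```

With an autograd framework the hand-written gradients would be unnecessary: keeping the source activation attached to its computation graph lets the error flow back through the source run automatically.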
00:13:43.640 | Findings from IIT.
00:13:46.960 | We showed that IIT achieves state-of-the-art results on
00:13:51.440 | that MNIST pointer value retrieval task that I mentioned before,
00:13:54.560 | as well as ReScan,
00:13:55.760 | which is a grounded language understanding benchmark.
00:13:58.660 | We also showed that IIT can be used as a distillation objective,
00:14:04.000 | where essentially what we do is distill
00:14:06.640 | teacher models into student models,
00:14:08.720 | forcing them not only to conform in their input-output behavior,
00:14:12.280 | but also conform at the level of
00:14:14.680 | their internal representations under
00:14:17.060 | the counterfactuals that we create for IIT.
00:14:19.900 | This is exciting to me because I think it's
00:14:21.840 | a powerful distillation method and it also shows you that
00:14:25.560 | the causal model that we use for IIT can be quite abstract.
00:14:29.480 | In this case, it's just a high-level constraint on what we
00:14:32.780 | want the teacher and student models to look like.
00:14:36.600 | We also showed that IIT can be used to induce
00:14:39.760 | internal representations of characters in
00:14:42.280 | language models that are based in subword tokenization.
00:14:45.760 | We showed that this helps with a variety
00:14:47.760 | of character level games and tasks.
00:14:50.060 | This is IIT being used to strike a balance.
00:14:52.860 | Subword models seem to be our best language models,
00:14:56.240 | but we have tasks that require knowledge of characters.
00:15:00.000 | What we do with IIT is imbue these models with
00:15:02.640 | knowledge of characters in their internal states.
00:15:06.560 | Finally, we recently used IIT to create
00:15:09.760 | concept level methods for explaining model behavior.
00:15:13.280 | That's a technique that we call causal proxy models,
00:15:16.160 | and it essentially leverages the core insight of IIT.
00:15:20.960 | Again, we have this course notebook, IIT equality.
00:15:24.840 | It covers abstraction analyses and then also shows
00:15:27.960 | you how to train models in this IIT mode.
00:15:31.920 | We can return to our scorecard.
00:15:34.320 | Now I have smileys across the board,
00:15:36.640 | and I claim that I have justified all of those smileys.
00:15:39.840 | I feel that this does point to intervention-based methods as
00:15:43.280 | the best bet we have for
00:15:45.200 | deeply understanding how NLP models work.