Stanford XCS224U | Analysis Methods for NLU, Pt 4: Causal Abstraction & Interchange Intervention Training (IIT)
Chapters
0:00 Intro
0:27 Recipe for causal abstraction
6:0 Interchange intervention accuracy (IIA)
8:3 Findings from causal abstraction
10:34 Connections to the literature
11:2 Summary
11:36 Method
13:45 Findings from IIT
This is part 4 in our series on analysis methods for NLP. We've come to our third set of methods, causal abstraction. I've been heavily involved with developing these methods, and I think they're tremendously exciting because they offer real causal insights into how our otherwise opaque models work.
Let's begin with a recipe for this causal abstraction analysis. In step 1, you state a hypothesis about some aspect of your target model's causal structure, and you could express this as a small computer program. In step 2, we're going to search for an alignment between variables in that causal model and sets of neurons in the target model. In step 3, we perform experiments to assess the extent to which those variables and sets of neurons align with each other. That brings us to the core technique in causal abstraction analysis, the interchange intervention. Much of this screencast is going to be devoted to giving you a feel for how interchange interventions work.
To illustrate, let's return to our simple neural network that takes in three numbers and adds them together. We assume that this network is successful at its task, and our question is how it manages to do that. Here is one hypothesis, expressed as a small causal model. The idea behind this causal model is that the network computes the sum of the first two inputs as an intermediate variable S1, carries the third input forward as a variable W, and those two variables directly contribute to the output of the model. If we can establish that this causal model accurately describes what is happening with our otherwise opaque neural model, then we have gained a real understanding of it.
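To make that concrete, here is a minimal sketch of the kind of small computer program such a causal model amounts to for the addition example. The variable names S1 and W follow the lecture's example; writing the decomposition this way (S1 sums the first two inputs, W carries the third) is exactly the hypothesis under discussion, expressed as code.

```python
# A minimal sketch of the hypothesized high-level causal model for the
# three-number addition task, written as a small Python program.

def causal_model(x, y, z):
    S1 = x + y          # intermediate variable: sum of the first two inputs
    W = z               # intermediate variable: carries the third input forward
    output = S1 + W     # the output is computed from S1 and W alone
    return {"S1": S1, "W": W, "output": output}

print(causal_model(1, 3, 5))  # {'S1': 4, 'W': 5, 'output': 9}
```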
We're going to use interchange interventions to assess the hypothesis that the neural representation L3 plays the same role as S1. The first intervention happens on the causal model. We take the value that S1 has for a right-hand example and place it in the corresponding place in the left-hand example, and then we study the resulting output. The causal model is completely understood by us, so there is no mystery about what will happen; in particular, the inputs that feed into the variable that we intervened on don't matter in this case. The causal model holds no surprises. We assume that we understand it before we begin the analysis.
The interesting part comes when we think about the neural model. We don't know how it works, and we're going to try to use these interventions to uncover that. We process both examples with the network, and now we're going to intervene on the L3 state. We take the L3 values computed for the right-hand example and place them in the corresponding spot in the left-hand example, and we study the output. If the output of the intervened neural model matches the output of the intervened causal model, then we have one piece of evidence that L3 plays the same causal role as S1. If we repeat this intervention for every conceivable input to these models, and we always see this alignment between causal model and neural model, we have proven that L3 plays the same causal role as S1.
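Here is a minimal sketch of this pair of interventions in code. It assumes a toy hand-built network whose three hidden units I index as L1, L2, L3 (indices 0, 1, 2); the weights are chosen purely so the example runs end to end and are not the lecture's actual trained network.

```python
import torch

# Hand-built stand-in for the trained addition network.
W_IN = torch.tensor([[0., 0., 1.],   # L1 (index 0) happens to copy the third input
                     [0., 0., 0.],   # L2 (index 1) carries nothing
                     [1., 1., 0.]])  # L3 (index 2) happens to sum the first two inputs
W_OUT = torch.tensor([1., 0., 1.])   # output reads off L1 + L3

def run_network(inputs, swap=None):
    """Forward pass; `swap` = (hidden_index, value) overwrites one hidden unit."""
    hidden = W_IN @ torch.tensor(inputs, dtype=torch.float)
    if swap is not None:
        idx, value = swap
        hidden = hidden.clone()
        hidden[idx] = value
    return (W_OUT @ hidden).item(), hidden

def run_causal_model(inputs, swap=None):
    """The high-level program; `swap` = (variable_name, value)."""
    x, y, z = inputs
    values = {"S1": x + y, "W": z}
    if swap is not None:
        name, value = swap
        values[name] = value
    return values["S1"] + values["W"], values

def interchange(base, source, hidden_idx, variable):
    """One interchange intervention under the alignment hidden_idx <-> variable.

    Returns True if the intervened network and intervened causal model agree.
    """
    # Intervene on the causal model: take the variable's value from the source
    # run and drop it into the base (left-hand) run.
    _, source_vars = run_causal_model(source)
    causal_out, _ = run_causal_model(base, swap=(variable, source_vars[variable]))
    # Intervene on the network in the corresponding way.
    _, source_hidden = run_network(source)
    neural_out, _ = run_network(base, swap=(hidden_idx, source_hidden[hidden_idx].item()))
    return causal_out == neural_out

# One piece of evidence that L3 (index 2) plays the same role as S1:
print(interchange(base=(1., 3., 5.), source=(2., 6., 2.), hidden_idx=2, variable="S1"))  # True
```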
Now consider the neural representation L1. Suppose we hypothesize that it plays the same role as the causal variable W. Again, let's first intervene on the causal model. We take the value that W has in the right-hand example and place it in the corresponding place in the left-hand model. We study the output, and the intervention has changed the output to 10. We then do the parallel thing with the neural model: we take the L1 representation computed for the right-hand example, place it in the corresponding spot in the left, and study the output. If the two outputs correspond, we have a single piece of evidence that L1 and W are causally aligned in this way. If we repeat this intervention for every possible input and always see this correspondence, we have proven that L1 and W play the same causal roles.
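Using the same hypothetical `interchange` helper sketched above, the L1/W test is the identical procedure applied at a different site, and exhaustively sweeping a small grid of inputs stands in for "every possible input" in this toy setting.

```python
# One piece of evidence that L1 (hidden index 0) plays the same role as W,
# reusing the hypothetical `interchange` helper from the sketch above.
print(interchange(base=(1., 3., 5.), source=(2., 6., 2.), hidden_idx=0, variable="W"))  # True

# A small grid of inputs approximates "every possible input" for this toy case.
grid = [(float(a), float(b), float(c)) for a in range(3) for b in range(3) for c in range(3)]
print(all(interchange(b, s, hidden_idx=0, variable="W") for b in grid for s in grid))  # True
```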
What about L2? Suppose we intervene on L2 in every way we can think of, and we never see an impact on the output behavior of the model. Then we have evidence that L2 plays no causal role in the input-output behavior of this network. Since we can assume that the input variables are aligned across the causal and neural models, and we can assume that the output variables are aligned, we have now fully proven via all these intervention experiments that the causal model is an abstraction of the neural network under this alignment. That is exciting. If we have actually established this, then we are licensed to allow the neural model to fall away, and we can reason entirely in terms of the causal model, secure in the knowledge that the two models are causally aligned.
Now, that is an ideal of causal abstraction analysis. There are a few things from the real world that are going to intervene. The first is that we can never perform the full set of interventions; even for the case of my tiny addition network, the space of possible interventions is too large to enumerate. The second is that, in practice, we're never going to see perfect causal abstraction relationships, because of the messy nature of the naturally trained models that we use.
This is where we need a graded way of evaluating our hypotheses, and I think interchange intervention accuracy (IIA) is a natural choice. IIA is the percentage of interchange interventions that lead to outputs matching those of the causal model under the chosen alignment. You can think of it as an accuracy measure for your hypothesized alignment.

A few properties of IIA are worth keeping in mind. First, it can actually exceed task performance: if the interchange interventions put the model into a better state than it was in originally, you can see higher performance from these Frankenstein examples that you have created. Second, IIA depends heavily on the set of interchange interventions that you decided to perform; that choice is up to you, and it will be a factor in shaping your accuracy results. Relatedly, you should keep track of how many interchange interventions should change the output label. Those are the interventions that provide the sharpest causal insights, because you see exactly what should happen in terms of changes once you have performed the intervention. Having an abundance of these causally insightful interventions is the most powerful thing you can do in terms of building an argument.
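As a rough sketch, IIA amounts to counting agreements over whatever set of intervention pairs you chose. The `agree` callback below is a stand-in for one interchange intervention under a fixed hypothesized alignment, for example the toy `interchange` helper from the earlier sketch; in a real classification setting you would compare predicted labels rather than numeric outputs.

```python
# A minimal sketch of interchange intervention accuracy (IIA), assuming an
# `agree(base, source)` callback that runs one interchange intervention under
# a fixed hypothesized alignment and reports whether the intervened network
# matches the intervened causal model.

def interchange_intervention_accuracy(pairs, agree):
    """IIA = fraction of (base, source) intervention pairs on which the two models agree."""
    matches = sum(1 for base, source in pairs if agree(base, source))
    return matches / len(pairs)

# Example usage with the toy helper from the earlier sketch, testing L3 <-> S1:
pairs = [((1., 3., 5.), (2., 6., 2.)), ((0., 1., 2.), (4., 4., 4.))]
print(interchange_intervention_accuracy(
    pairs, lambda b, s: interchange(b, s, hidden_idx=2, variable="S1")))  # 1.0
```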
Let me briefly summarize some findings from causal abstraction. First, fine-tuned BERT models succeed at hard out-of-domain examples because they are abstracted by simple monotonicity programs. I emphasized the word "because" and wrote it in blue there because I am not being casual with that causal language; that is exactly the kind of claim that causal abstraction licenses you to be able to make. Second, models succeed at the MQNLI task because they find compositional solutions. MQNLI is the multiply quantified NLI benchmark; it is full of complex interactions between quantifiers and modifiers and so forth, and we show with causal abstraction that models that succeed at it do so because they find compositional solutions. Third, models succeed at the MNIST pointer value retrieval task because they are abstracted by simple conditional programs that capture the task structure for these very successful models.
Another nice point here is that we're starting to see that when a neural model is perfectly abstracted by a symbolic program, then there really is no meaningful difference between the two, which leads you to wonder whether there's truly a meaningful difference between symbolic AI and neural AI.
There is also evidence from this line of work that language models form coherent entity and situation representations.
If you would like to get hands-on with these ideas, I would encourage you to check out our notebook. I can't do justice to the full literature here, but I did want to call out that causal abstraction is part of a larger family of intervention-based methods for understanding our models. I've listed a few other exciting entries in this literature here, and there are even more connections to the literature than I have space for on this slide.
To summarize where this leaves us: we're talking about intervention-based methods, and I claim that they can characterize representations richly. After all, we show how those representations correspond to interpretable variables in a high-level causal model. I also still have a smiley under improved models, and I'll justify that now under the heading of interchange intervention training.
Interchange intervention training (IIT) builds directly on causal abstraction with interchange interventions. Consider again an interchange intervention using our addition example, but suppose that this time my intervention on L3 has led to an incorrect result: the output of the intervened network does not match what the intervened causal model says it should be. But I think you can also see in here an opportunity to do better. We can correct this misalignment if we want to. The intervened causal model gives us a label, and that gives us an error signal that we can use to update the parameters of this model and make it conform more closely to our underlying causal model under this alignment.
We get our error signal, and that flows back through the network in the usual way, with one twist. For L1, the gradients flow back as usual to the input states. But for L3, we have a more complicated update. We have literally copied over the full computation graph, in the PyTorch sense, including all the gradient information. What we get for L3 is a double update, coming from our current example as well as the source example that provided the intervention. The effect of training in this way, using the intervened causal model for the labels as we've done here, is that we push the model to modularize information in the way the causal model specifies. With IIT, the importance of assessing alignments falls away, and the emphasis shifts to pushing models to modularize information in the ways we have hypothesized, in the hopes that they will then perform in more systematic ways and be better at the tasks we've set for them.
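Here is a minimal sketch of what one IIT training step might look like in PyTorch for the addition example. The architecture, the assumption that hidden unit 2 is the one aligned with S1, and the optimizer settings are all illustrative choices, not the exact setup from the lecture or the papers. The two key ingredients are that the swapped-in representation stays attached to the source example's computation graph, and that the label comes from the intervened causal model rather than from the data.

```python
import torch
import torch.nn as nn

class AdditionNet(nn.Module):
    def __init__(self, hidden_size=3):
        super().__init__()
        self.layer1 = nn.Linear(3, hidden_size)
        self.layer2 = nn.Linear(hidden_size, 1)

    def forward(self, x, swap=None):
        hidden = torch.relu(self.layer1(x))
        if swap is not None:
            idx, value = swap
            # Splicing in the source-derived value keeps the source computation
            # graph attached, so gradients will flow back through both examples.
            hidden = torch.cat([hidden[:idx], value.unsqueeze(0), hidden[idx + 1:]])
        return self.layer2(hidden)

def counterfactual_label(base, source):
    # Intervened causal model: S1 comes from the source example, W from the base.
    s1 = source[0] + source[1]
    w = base[2]
    return (s1 + w).unsqueeze(0)

model = AdditionNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

base = torch.tensor([1., 3., 5.])
source = torch.tensor([2., 6., 2.])
ALIGNED_UNIT = 2  # hypothesized alignment: hidden unit 2 <-> causal variable S1

# Forward pass on the source to get the representation we will swap in.
source_hidden = torch.relu(model.layer1(source))
# Forward pass on the base with the interchange intervention applied.
prediction = model(base, swap=(ALIGNED_UNIT, source_hidden[ALIGNED_UNIT]))
# The label comes from the intervened causal model, not from the dataset.
loss = loss_fn(prediction, counterfactual_label(base, source))

optimizer.zero_grad()
loss.backward()   # gradients reach layer1 through both the base and the source paths
optimizer.step()
```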
Let me briefly mention some findings from IIT. We showed that IIT achieves state-of-the-art results on that MNIST pointer value retrieval task that I mentioned before, as well as on a grounded language understanding benchmark. We also showed that IIT can be used as a distillation objective, where we train student models to mimic teacher models, forcing them to conform not only in their input-output behavior but also in their internal representations. That makes for a powerful distillation method, and it also shows you that the causal model we use for IIT can be quite abstract. In this case, it's just a high-level constraint on what we want the teacher and student models to look like.
We also showed that IIT can be used to induce internal representations of characters in language models that are based on subword tokenization. There is a real tension here: subword models seem to be our best language models, but we have tasks that require knowledge of characters. What we do with IIT is imbue these models with knowledge of characters in their internal states.
Finally, we have used the core idea behind IIT to develop concept-level methods for explaining model behavior. That's a technique that we call causal proxy models, and it essentially leverages the core insight of IIT. Again, we have this course notebook, IIT equality. It covers causal abstraction analyses and then also shows you how to use IIT to train models. That brings us back to our summary of analysis methods, and I claim that I have justified all of those smileys. I feel that this does point to intervention-based methods as an especially promising path toward both understanding our models and improving them.