
Stanford XCS224U: Analysis NLU, Part 5: Distributed Alignment Search (DAS) & Conclusion | Spring 23



00:00:00.000 | Welcome back everyone.
00:00:06.080 | This is the fifth and final screencast in
00:00:08.200 | our series on analysis methods for NLP.
00:00:10.620 | I'm going to seize this moment to tell you
00:00:12.660 | about a brand new method we've been developing,
00:00:15.080 | distributed alignment search.
00:00:17.000 | I think this overcomes crucial limitations with
00:00:19.800 | causal abstraction as I presented it to you before.
00:00:23.140 | I'm going to give you a high-level overview of
00:00:25.640 | DAS and then use that as
00:00:27.660 | a starting point to think even further
00:00:29.640 | into the future about analysis methods in the field.
00:00:33.440 | To start, let's return to our scorecard.
00:00:36.220 | I've been using this throughout the unit.
00:00:37.920 | I feel I have justified
00:00:39.820 | the three smileys along the interventions row,
00:00:42.960 | but there remain really
00:00:44.940 | pressing issues for this class of methods.
00:00:47.740 | Let me identify two of them.
00:00:49.500 | First, alignment search is expensive.
00:00:52.780 | The name of the game here is to define
00:00:54.720 | a high-level causal model and
00:00:56.920 | then align variables in that model
00:00:59.540 | with sets of neurons in the target neural model.
00:01:02.920 | For a complex causal model
00:01:05.100 | and a modern large language model,
00:01:07.460 | the number of possible alignments
00:01:09.300 | in that setting is enormous.
00:01:10.920 | I mean, to call it astronomical would be to fail to
00:01:14.580 | do justice to just how large the space is.
00:01:18.080 | As a result, we introduced lots of
00:01:20.340 | approximations and we could easily
00:01:22.280 | miss a really good alignment.
00:01:25.260 | The second issue is even deeper.
00:01:28.120 | Causal abstraction could fail to
00:01:30.460 | find genuine causal structure because
00:01:32.880 | we presume that the analysis is done in the standard basis.
00:01:36.280 | The central insight behind DAS is that there might be
00:01:39.820 | interpretable structure that we would find if we
00:01:42.760 | simply rotated some of
00:01:44.340 | these representations in various ways.
00:01:47.060 | In fact, the target of DAS is
00:01:49.100 | a rotation matrix that we learn,
00:01:51.480 | and that helps us find optimal alignments.
00:01:55.600 | I'm going to keep this high-level and
00:01:58.060 | intuitive, and for that I'll use
00:01:59.780 | a very simple running example.
00:02:02.260 | This is Boolean conjunction.
00:02:03.980 | I have the causal model on the left here.
00:02:06.320 | It takes in Boolean variables.
00:02:08.460 | It has intermediate variables for
00:02:10.700 | those inputs, and then it outputs true
00:02:13.380 | if both of the inputs are true, and false otherwise.
00:02:16.960 | On the right, I have a very simple neural model.
00:02:20.860 | The neural model perfectly solves
00:02:23.440 | our Boolean conjunction task with this set of parameters.
00:02:27.780 | That's the starting point.
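
To make this concrete, here is a minimal sketch of the running example in Python. The exact parameters shown on the slide are not reproduced in the transcript, so the weights below are illustrative assumptions: they solve Boolean conjunction perfectly, but they happen to store the first input in the second hidden unit and vice versa, which is what sets up the alignment mistake discussed next.

```python
import numpy as np

def causal_model(x1: bool, x2: bool) -> bool:
    v1, v2 = x1, x2      # intermediate variables V1 and V2 copy the inputs
    return v1 and v2     # output True iff both inputs are True

# Neural model: Booleans are encoded as +1 (True) / -1 (False).
W = np.array([[0.0, 1.0],    # H1 copies the *second* input
              [1.0, 0.0]])   # H2 copies the *first* input
w_out, b_out = np.array([1.0, 1.0]), -1.0

def encode(x1: bool, x2: bool):
    return np.array([1.0 if x1 else -1.0, 1.0 if x2 else -1.0])

def neural_model(x):
    h = W @ x                         # hidden layer (H1, H2)
    return float(w_out @ h + b_out)   # positive means True, negative means False

# Behaviorally perfect on all four inputs:
for a in (True, False):
    for b in (True, False):
        assert (neural_model(encode(a, b)) > 0) == causal_model(a, b)
```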
00:02:30.140 | Now, in the classical causal abstraction mode,
00:02:33.600 | I could define an alignment
00:02:35.580 | like this with these red arrows.
00:02:37.460 | It looks good. I align the inputs as you would expect.
00:02:40.700 | I align the outputs as you would expect,
00:02:43.080 | and I'll add in the decision procedure that if
00:02:45.700 | the neural network outputs a negative value,
00:02:48.080 | that's false for the causal model,
00:02:50.220 | and if it outputs a positive value,
00:02:52.260 | that's true for the causal model.
00:02:54.100 | That's intuitive.
00:02:55.380 | Then I did the intuitive thing of
00:02:57.640 | aligning V1 with H1 and V2 with H2.
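
As a small aside, the chosen alignment and the output decision procedure can be written down directly; the dict below, continuing the sketch above, is just an illustrative way of recording the (as we will see, mistaken) choice.

```python
# The chosen alignment of causal-model variables to hidden units,
# plus the decision procedure for reading off the network's output.
alignment = {"V1": "H1", "V2": "H2"}   # the intuitive but, it turns out, wrong choice

def decode(output: float) -> bool:
    return output > 0   # negative output means False, positive means True
```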
00:03:02.260 | Now, this neural model is behaviorally perfect, as I said,
00:03:07.380 | but the causal model does not abstract
00:03:09.260 | it under our chosen alignment.
00:03:12.300 | That chosen alignment bit is
00:03:13.780 | crucial and I'll just give you the spoiler here.
00:03:16.580 | What I did inadvertently
00:03:19.420 | is reverse the order of those internal variables.
00:03:22.700 | I should have mapped V1 to H2 and V2 to H1.
00:03:27.180 | What we're doing with this simple example is simulating
00:03:30.280 | the situation in which I just made
00:03:32.680 | a mistake about what set of alignments I decided to look at,
00:03:35.960 | and I picked one that is suboptimal in terms of
00:03:38.840 | finding structure that is actually there.
00:03:42.160 | The promise of DAS is that
00:03:44.680 | even if I start with this incorrect alignment,
00:03:47.240 | a learned rotation will help me find the structure that is there.
00:03:50.480 | First, I'll just substantiate for you that we do
00:03:53.460 | indeed have a failure of causal abstraction.
00:03:56.160 | I'll show you a failed interchange intervention.
00:03:58.800 | On the top, as usual,
00:04:00.440 | we do an intervention with our causal model.
00:04:02.960 | We take V1 from the right example and put it into
00:04:06.020 | the corresponding place in the left example.
00:04:09.680 | The original output for the left example was false,
00:04:12.760 | but because of the intervention,
00:04:14.600 | we should get the output value true.
00:04:17.080 | When we do the corresponding
00:04:19.600 | aligned intervention on the neural model,
00:04:22.700 | we end up with an output state that is negative.
00:04:25.620 | That means predicting false,
00:04:27.420 | but the causal model said we should predict true,
00:04:30.000 | and that's exactly the kind of thing that leads us to say that
00:04:33.000 | this causal model and this neural one
00:04:35.600 | do not stand in the abstraction relationship.
00:04:38.360 | The two models have unequal counterfactual predictions.
00:04:43.340 | That is the heart of it. But remember,
00:04:44.940 | we already know why they do.
00:04:46.680 | It's because I chose the wrong alignment due to
00:04:49.840 | bad luck or resource shortages or whatever.
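
Here is what that failed interchange intervention looks like, continuing the Boolean-conjunction sketch above. The particular base and source inputs are illustrative choices that mirror the pattern described in the lecture, not the exact examples on the slide.

```python
def intervene_h(base_x, source_x, idx):
    """Run the base input, but overwrite hidden unit `idx` with the value
    it takes on the source input, then finish the forward pass."""
    h = W @ base_x
    h[idx] = (W @ source_x)[idx]
    return float(w_out @ h + b_out)

base   = encode(False, True)   # causal model: V1=False, V2=True -> output False
source = encode(True,  False)  # the source supplies V1=True

# Causal model after putting the source's V1 into the base: (True, True) -> True.
# Neural model under the wrong alignment, where V1 is aligned with H1:
print(intervene_h(base, source, idx=0))   # -3.0, i.e. the network predicts False
# The counterfactual predictions disagree, so this alignment fails causal
# abstraction even though the network is behaviorally perfect.
```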
00:04:53.480 | Here's the crucial insight behind DAS.
00:04:57.000 | The alignment relationship does hold in a non-standard basis.
00:05:01.400 | If I take the current network and the current alignment and I
00:05:04.320 | simply rotate H1 and H2 using this rotation matrix,
00:05:09.260 | then I have a network that is behaviorally perfect
00:05:12.840 | and satisfies the causal abstraction relationship.
00:05:16.880 | Causal abstraction in its classical mode
00:05:19.560 | missed this because of the standard basis we chose,
00:05:23.140 | and the essence of the problem is that there was
00:05:25.260 | no reason to choose the standard basis.
00:05:27.340 | It's intuitive for us as humans,
00:05:29.500 | but there's no reason to presume that
00:05:31.360 | our neural models prefer to operate in that basis.
00:05:34.680 | This example reveals that we might find
00:05:37.820 | interpretable structure by dropping that assumption about the basis.
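
Continuing the same sketch, here is one rotation, picked by hand for this toy network, under which the original alignment does satisfy the abstraction relationship. In DAS the rotation is learned rather than hand-picked; this only demonstrates that such a basis exists.

```python
R = np.array([[0.0,  1.0],
              [-1.0, 0.0]])   # orthogonal with determinant +1, i.e. a rotation

def intervene_rotated(base_x, source_x, idx):
    """Rotate the hidden states, swap rotated coordinate `idx`, un-rotate,
    and finish the forward pass."""
    y_base   = R @ (W @ base_x)
    y_source = R @ (W @ source_x)
    y_base[idx] = y_source[idx]
    h = R.T @ y_base              # un-rotate back into the network's own basis
    return float(w_out @ h + b_out)

# Same base/source pair as before, aligning V1 with the first *rotated* coordinate:
print(intervene_rotated(base, source, idx=0))   # +1.0, i.e. True, matching the causal model
```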
00:05:43.160 | The essence of DAS,
00:05:45.400 | keep an eye on the ball here,
00:05:47.280 | is really learning this rotation matrix.
00:05:50.460 | That is the target of learning in DAS,
00:05:53.700 | and then the rotation matrix becomes the asset that you can use for
00:05:57.720 | actually finding and displaying and
00:05:59.820 | assessing internal causal structure.
00:06:03.320 | Here's a more abstract, high-level overview of how this
00:06:06.920 | might happen using a pair of aligned interventions.
00:06:10.400 | I have my target model in red here.
00:06:12.940 | I have two source models on the left and right.
00:06:16.260 | They process their various examples and we're going to
00:06:19.020 | target the variables X_1,
00:06:21.040 | X_2, and X_3 across these different uses of the model.
00:06:24.840 | The first thing we do is rotate
00:06:28.200 | that representation that we targeted to create some new variables,
00:06:32.000 | Y_1, Y_2, and Y_3.
00:06:34.600 | Remember R here is the essence of DAS,
00:06:37.840 | and that is the matrix that we're going to learn
00:06:40.280 | using essentially interchange intervention training.
00:06:43.960 | Having done this rotation,
00:06:45.840 | I then create a new vector that comes from
00:06:48.860 | intervening on Y_1 and Y_2 with values from the two sources,
00:06:51.820 | and then copying Y_3 over from the base example.
00:06:58.180 | That gives me this new vector here,
00:06:59.960 | and then we un-rotate and we do the intervention.
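
In code, that sequence of steps might look like the following schematic sketch. The names here (the hidden states, the rotation R, and the coordinate groups dims1 and dims2) are placeholders for whatever layer of the target model is being analyzed; nothing is tied to a particular architecture.

```python
import numpy as np

def distributed_intervention(h_base, h_src1, h_src2, R, dims1, dims2):
    """Rotate the targeted representation, take the rotated coordinates in
    `dims1` from source 1 (Y_1) and those in `dims2` from source 2 (Y_2),
    keep the remaining coordinates from the base (Y_3), then un-rotate."""
    y = R @ h_base                 # fresh array, safe to modify below
    y[dims1] = (R @ h_src1)[dims1]
    y[dims2] = (R @ h_src2)[dims2]
    return R.T @ y                 # the frozen forward pass continues from this state
```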
00:07:05.160 | Remember the essence of DAS is that we're going to freeze the model parameters.
00:07:09.820 | This is an analysis method,
00:07:11.880 | not a method where we change the core underlying target model.
00:07:15.540 | But the thing that we do is learn a rotation matrix that essentially
00:07:19.960 | maximizes the interchange intervention accuracy that we get from
00:07:24.720 | doing this rotation and then un-rotation to create these intervened states.
00:07:29.800 | This is a blend of IIT-like techniques,
00:07:33.340 | as well as classical causal abstraction.
00:07:35.940 | We keep the model frozen because we want to interpret it,
00:07:38.780 | but we learn that rotation matrix.
00:07:41.300 | That's the essence of DAS.
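
Here, finally, is a minimal sketch of how the rotation might be trained on the toy conjunction example from earlier, with the network weights frozen and only R updated. This is an assumed recipe for illustration, not the actual DAS training code: R is parametrized as the matrix exponential of a skew-symmetric matrix so it always stays a rotation, and the loss pushes the intervened network output toward the counterfactual label predicted by the causal model.

```python
import itertools
import torch

# Frozen copies of the toy network's parameters from the earlier sketch.
Wt   = torch.tensor(W, dtype=torch.float32)
wout = torch.tensor(w_out, dtype=torch.float32)

# The only trainable object: a generator for the rotation matrix.
A   = torch.nn.Parameter(0.01 * torch.randn(2, 2))
opt = torch.optim.Adam([A], lr=0.05)

def rotation():
    return torch.matrix_exp(A - A.T)   # exp of a skew-symmetric matrix is a rotation

bools = (True, False)
for step in range(300):
    R, losses = rotation(), []
    # Interchange intervention training over all (base, source) input pairs.
    for b1, b2, s1, s2 in itertools.product(bools, repeat=4):
        base = torch.tensor(encode(b1, b2), dtype=torch.float32)
        src  = torch.tensor(encode(s1, s2), dtype=torch.float32)
        y = R @ (Wt @ base)
        # Align V1 with rotated coordinate 0: take that coordinate from the source.
        y = torch.cat([(R @ (Wt @ src))[:1], y[1:]])
        logit = wout @ (R.T @ y) + b_out
        # Counterfactual label: the causal model run with the source's V1 and the base's V2.
        target = torch.tensor(1.0 if causal_model(s1, b2) else 0.0)
        losses.append(
            torch.nn.functional.binary_cross_entropy_with_logits(logit, target))
    loss = torch.stack(losses).mean()
    opt.zero_grad(); loss.backward(); opt.step()

# With the model frozen, training R this way typically recovers a basis in which
# the original alignment's counterfactual predictions match the causal model's.
```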
00:07:44.240 | Findings of DAS so far,
00:07:46.680 | these are pretty nuanced.
00:07:48.160 | In our DAS paper,
00:07:50.160 | we show that models learn truly hierarchical solutions to a hierarchical equality task.
00:07:56.020 | This is in fact the one that's reviewed in our notebook for this course.
00:07:59.980 | But those solutions are easy to miss with
00:08:02.980 | standard causal abstraction because of this non-standard basis issue.
00:08:07.960 | Here's a more nuanced finding.
00:08:10.160 | In earlier causal abstraction work,
00:08:12.560 | we found that models learn theories of lexical entailment and negation that
00:08:17.080 | align with a high-level intuitive causal model.
00:08:21.040 | But with DAS, we can uncover that they do that in a brittle way that actually
00:08:26.160 | preserves the identities of the lexical items rather than
00:08:29.340 | truly learning a general solution to the entailment issue.
00:08:34.320 | The third finding is from a separate paper.
00:08:37.900 | This is tremendously exciting because it shows that we can
00:08:40.740 | scale to levels that were impossible
00:08:43.360 | before, because we no longer need to search over alignments;
00:08:46.600 | now we essentially learn the alignment.
00:08:49.460 | We scaled DAS to Alpaca and we discovered that Alpaca,
00:08:53.840 | a seven-billion-parameter model,
00:08:55.660 | implements an intuitive algorithm to solve a numerical reasoning task.
00:09:00.860 | I think this is just the start of the potential that we see for using
00:09:05.160 | DAS to understand our biggest and most performant
00:09:09.100 | and most interesting large language models.
00:09:13.140 | Let me turn now to wrapping up just some high-level conclusions here.
00:09:18.100 | First, I wanted to return to this diagram that I used
00:09:21.500 | to motivate analysis methods in general.
00:09:24.600 | We have these incredibly important goals for the field,
00:09:27.860 | identifying approved and disapproved uses,
00:09:30.640 | identifying and correcting pernicious social biases,
00:09:34.480 | and guaranteeing that models are safe in certain contexts.
00:09:38.260 | I feel that we cannot offer guarantees about these issues
00:09:42.220 | unless we have analytic guarantees about the underlying models.
00:09:46.460 | For me, that implies a truly deep causal understanding
00:09:50.760 | of the mechanisms that shape their input-output behavior.
00:09:54.180 | For that reason, I think the analysis project in
00:09:58.300 | NLP is one of the most pressing projects for the field.
00:10:02.860 | In that spirit, let's look ahead a little bit to
00:10:05.440 | the near future of explainability research for the field.
00:10:09.220 | First, as I said,
00:10:10.440 | we should be seeking causal explanations,
00:10:13.680 | but we also need human interpretable ones.
00:10:16.320 | If causality were the only requirement,
00:10:19.040 | we could just give low-level mechanistic,
00:10:21.960 | mathematical explanations of how the transformer
00:10:24.840 | worked and call that explainability research.
00:10:27.880 | But that's at the wrong level for humans trying to offer
00:10:31.320 | guarantees about safety and trustworthiness.
00:10:34.520 | We need human interpretable explanations.
00:10:37.800 | We need to apply these methods to ever larger instruct-trained LLMs.
00:10:43.200 | Those are the most relevant artifacts for the current moment.
00:10:47.280 | I think we're starting to approach this goal with DAS.
00:10:50.760 | I just mentioned how we can apply it to Alpaca.
00:10:53.580 | I think we could scale even further,
00:10:55.800 | but we really want to be
00:10:57.160 | unconstrained in terms of what we can explore,
00:10:59.520 | and that requires a lot more innovation in the space.
00:11:03.520 | Then finally, to return to
00:11:06.360 | the previous unit and our discussion of
00:11:08.360 | COGS and ReCOGS and compositionality,
00:11:11.280 | I think we're seeing increasing evidence that
00:11:13.800 | models are inducing a semantics,
00:11:16.640 | that is, a mapping from language into a network of concepts.
00:11:20.660 | If they are doing that and if we can find strong evidence for that,
00:11:24.800 | it's tremendously eye-opening about what kinds of
00:11:27.800 | data-driven learning processes could lead
00:11:30.440 | a language technology to actually
00:11:32.880 | have induced a semantics from its experiences,
00:11:35.360 | which would in turn lead us down the road of having
00:11:38.480 | many more guarantees that their behavior would be systematic,
00:11:42.280 | which could be a basis for them being, again,
00:11:44.880 | trustworthy, safe, and secure,
00:11:47.480 | and all of those important goals for the field and for society.