Stanford XCS224U: Analysis NLU, Part 5: Distributed Alignment Search (DAS) & Conclusion I Spring 23


Transcript

Welcome back everyone. This is the fifth and final screencast in our series on analysis methods for NLP. I'm going to seize this moment to tell you about a brand new method we've been developing, distributed alignment search. I think this overcomes crucial limitations with causal abstraction as I presented it to you before.

I'm going to give you a high-level overview of DAS and then use that as a starting point to think even further into the future about analysis methods in the field. To start, let's return to our scorecard. I've been using this throughout the unit. I feel I have justified the three smileys along the interventions row, but there remain really pressing issues for this class of methods.

Let me identify two of them. First, alignment search is expensive. The name of the game here is to define a high-level causal model and then align variables in that model with sets of neurons in the target neural model. For a complex causal model and a modern large language model, the number of possible alignments is enormous.

I mean, to call it astronomical would be to fail to do justice to just how large the space is. As a result, we introduced lots of approximations and we could easily miss a really good alignment. The second issue is even deeper. Causal abstraction could fail to find genuine causal structure because we presume that we're doing it in a standard basis.
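To get a rough feel for the size of that space, here is a minimal sketch that counts only the non-empty sets of neurons a single high-level variable could be aligned with. The function name and the hidden size of 4096 are assumptions chosen for illustration; a full alignment over several variables and several layers would be larger still.

```python
def count_single_variable_alignments(num_neurons: int) -> int:
    """Number of non-empty neuron subsets one high-level variable could align with."""
    return 2 ** num_neurons - 1


# Three neurons already give 7 candidate subsets.
small = count_single_variable_alignments(3)

# Hypothetical hidden size of 4096 for a single layer of a large model.
large = count_single_variable_alignments(4096)
num_digits = len(str(large))

print(small, num_digits)
```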

The central insight behind DAS is that there might be interpretable structure that we would find if we simply rotated some of these representations in various ways. In fact, the target of DAS is a rotation matrix that we learn, and that helps us find optimal alignments. I'm going to keep this high-level and intuitive, and for that I'll use a very simple running example.

This is Boolean conjunction. I have the causal model on the left here. It takes in Boolean variables. It has intermediate variables for those inputs, and then it outputs true if both of the inputs were true, otherwise false. On the right, I have a very simple neural model. The neural model perfectly solves our Boolean conjunction task with this set of parameters.
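The setup can be sketched in a few lines. The transcript does not give the parameters from the slide, so the weights below (`W_HIDDEN`, `W_OUT`, `BIAS`) are hypothetical values picked only so that a tiny linear network solves conjunction; the function names are likewise assumptions.

```python
import itertools

import numpy as np

# Hypothetical parameters, not the ones on the lecture slide.
W_HIDDEN = np.array([[0.0, 1.0], [1.0, 0.0]])
W_OUT = np.array([1.0, 1.0])
BIAS = -1.5


def causal_model(p: bool, q: bool) -> bool:
    """High-level model: intermediate variables copy the inputs, output is AND."""
    v1, v2 = p, q
    return v1 and v2


def neural_model(p: bool, q: bool) -> float:
    """Low-level model: one hidden layer (H1, H2) and a scalar output."""
    x = np.array([float(p), float(q)])
    h = W_HIDDEN @ x
    return float(W_OUT @ h + BIAS)


def decode(output: float) -> bool:
    """Positive network output means true, negative means false."""
    return output > 0


# Check input-output agreement on all four Boolean inputs.
behaviorally_perfect = all(
    decode(neural_model(p, q)) == causal_model(p, q)
    for p, q in itertools.product([False, True], repeat=2)
)
print(behaviorally_perfect)
```

Agreement on all four inputs is only a behavioral check. It says nothing yet about whether the hidden units play the roles of the intermediate variables.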

That's the starting point. Now, in the classical causal abstraction mode, I could define an alignment like this with these red arrows. It looks good. I align the inputs as you would expect. I align the outputs as you would expect, and I'll add in the decision procedure that if the neural network outputs a negative value, that's false for the causal model, and if it outputs a positive value, that's true for the causal model.

That's intuitive. Then I did the intuitive thing of aligning V1 with H1 and V2 with H2. Now, the neural model is perfect behaviorally, as I said, but the causal model does not abstract the neural model under our chosen alignment. That chosen alignment bit is crucial and I'll just give you the spoiler here.

What I did inadvertently is reverse the order of those internal variables. I should have mapped V1 to H2 and V2 to H1. What we're doing with this simple example is simulating the situation in which I just made a mistake about what set of alignments I decided to look at, and I picked one that is suboptimal in terms of finding structure that is actually there.

The promise of DAS is that even if I start with this incorrect alignment, a learned rotation will help me find the correct one. First, I'll just substantiate for you that we do indeed have a failure of causal abstraction. I'll show you a failed interchange intervention. On the top, as usual, we do an intervention with our causal model.

We take V1 from the right example and put it into the corresponding place in the left example. The original output for the left example was false, but because of the intervention, we should get the output value true. When we do the corresponding aligned intervention on the neural model, we end up with an output state that is negative.

That means predicting false, but the causal model said we should predict true, and that's exactly the kind of thing that leads us to say that this causal model and this neural model are not in the abstraction relationship. The two models have unequal counterfactual predictions. That is the heart of it.
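Here is a minimal sketch of that failed interchange intervention. The network weights are hypothetical (the slide's parameters are not in the transcript) and are chosen so that H1 actually carries the second input; the base and source examples are assumptions consistent with the description above, where the base output is false and the intervened causal model outputs true.

```python
import numpy as np

# Hypothetical frozen network: H1 carries q and H2 carries p.
W_HIDDEN = np.array([[0.0, 1.0], [1.0, 0.0]])
W_OUT, BIAS = np.array([1.0, 1.0]), -1.5


def hidden(p, q):
    """Hidden representation (H1, H2) for one input pair."""
    return W_HIDDEN @ np.array([float(p), float(q)])


def readout(h):
    """Positive output decodes to True, negative to False."""
    return float(W_OUT @ h + BIAS) > 0


def causal_intervention(base, source):
    """Run the causal model on base, with V1 taken from source."""
    v1, v2 = source[0], base[1]
    return v1 and v2


def neural_intervention(base, source):
    """Run the network on base, with H1 taken from source (V1 aligned to H1)."""
    h = hidden(*base).copy()
    h[0] = hidden(*source)[0]
    return readout(h)


base, source = (False, True), (True, False)
causal_out = causal_intervention(base, source)
neural_out = neural_intervention(base, source)
print(causal_out, neural_out)
```

The two counterfactual predictions disagree for this pair, which is all it takes to show that the abstraction relationship fails under the V1 to H1 alignment.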

But remember, we already know why they do. It's because I chose the wrong alignment due to bad luck or resource shortages or whatever. Here's the crucial insight behind DAS. The alignment relationship does hold in a non-standard basis. If I take the current network and the current alignment and I simply rotate H1 and H2 using this rotation matrix, then I have a network that is behaviorally perfect and satisfies the causal abstraction relationship.
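A small sketch can make this concrete by measuring interchange intervention accuracy over every base and source pair, once in the standard basis and once after a rotation. Everything here is an assumption for illustration: the network weights are hypothetical, and the 90-degree `ROTATION` is simply one matrix that works for this toy, not the matrix shown on the slide.

```python
import itertools

import numpy as np

# Hypothetical frozen network: H1 carries q and H2 carries p.
W_HIDDEN = np.array([[0.0, 1.0], [1.0, 0.0]])
W_OUT, BIAS = np.array([1.0, 1.0]), -1.5
# Assumed 90-degree rotation; chosen by hand for this toy example.
ROTATION = np.array([[0.0, -1.0], [1.0, 0.0]])


def hidden(p, q):
    return W_HIDDEN @ np.array([float(p), float(q)])


def intervened_prediction(base, source, rotation, index):
    """Rotate, copy one rotated coordinate from source, un-rotate, read out."""
    y = rotation @ hidden(*base)
    y[index] = (rotation @ hidden(*source))[index]
    h = rotation.T @ y
    return float(W_OUT @ h + BIAS) > 0


def causal_prediction(base, source, index):
    """Causal model on base with V1 (index 0) or V2 (index 1) from source."""
    values = list(base)
    values[index] = source[index]
    return values[0] and values[1]


def iia(rotation):
    """Fraction of interventions where network and causal model agree."""
    inputs = list(itertools.product([False, True], repeat=2))
    trials = [
        (b, s, i) for b in inputs for s in inputs for i in (0, 1)
    ]
    hits = sum(
        intervened_prediction(b, s, rotation, i) == causal_prediction(b, s, i)
        for b, s, i in trials
    )
    return hits / len(trials)


iia_standard = iia(np.eye(2))   # original alignment, standard basis
iia_rotated = iia(ROTATION)     # same alignment, rotated basis
print(iia_standard, iia_rotated)
```

The network parameters are never touched. Only the basis in which the intervention is carried out changes.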

Causal abstraction in its classical mode missed this because of the standard basis we chose, and the essence of that is that there was no reason to choose the standard basis. It's intuitive for us as humans, but there's no reason to presume that our neural models prefer to operate in that basis.

This example reveals that we might find interpretable structure by dropping that assumption about the basis. The essence of DAS, keep an eye on the ball here, is really learning this rotation matrix. That is the target of learning in DAS, and then the rotation matrix becomes the asset that you can use for actually finding and displaying and assessing internal causal structure.
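To illustrate what it means for the rotation to be the thing that is learned, here is a toy sketch in which the network stays frozen and only a rotation angle is searched for. This is a deliberate simplification: it uses a brute-force grid over angles in two dimensions as a stand-in for the gradient-based training that DAS would use, and the network weights and all names are hypothetical.

```python
import itertools

import numpy as np

# Hypothetical frozen network: H1 carries q and H2 carries p.
W_HIDDEN = np.array([[0.0, 1.0], [1.0, 0.0]])
W_OUT, BIAS = np.array([1.0, 1.0]), -1.5


def hidden(p, q):
    return W_HIDDEN @ np.array([float(p), float(q)])


def rotation(theta):
    """2-D rotation matrix for angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def iia(theta):
    """Interchange intervention accuracy in the basis rotated by theta."""
    rot = rotation(theta)
    inputs = list(itertools.product([False, True], repeat=2))
    hits, total = 0, 0
    for base, source in itertools.product(inputs, repeat=2):
        for index in (0, 1):  # V1 <-> first coordinate, V2 <-> second
            y = rot @ hidden(*base)
            y[index] = (rot @ hidden(*source))[index]
            neural = float(W_OUT @ (rot.T @ y) + BIAS) > 0
            values = list(base)
            values[index] = source[index]
            hits += neural == (values[0] and values[1])
            total += 1
    return hits / total


# Grid search over angles: a stand-in for learning R by gradient descent.
thetas = np.linspace(0.0, np.pi, 181)
scores = [iia(t) for t in thetas]
best_theta = thetas[int(np.argmax(scores))]
print(np.degrees(best_theta), max(scores))
```

Because only the sign of the output matters in this toy, more than one angle may reach perfect accuracy, and `argmax` simply returns the first. In realistic settings the representation has many dimensions, so a grid is not an option and the rotation has to be optimized directly.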

Here's a more high-level abstract overview of how this might happen using a pair of aligned interventions. I have my target model in red here. I have two source models on the left and right. They process their various examples and we're going to target the variables X_1, X_2, and X_3 across these different uses of the model.

The first thing we do is rotate that representation that we targeted to create some new variables, Y_1, Y_2, and Y_3. Remember R here is the essence of DAS, and that is the matrix that we're going to learn using essentially interchange intervention training. Having done this rotation, I then create a new vector that comes from me deciding to do an intervention with Y_1, with Y_2, and then copying Y_3 over from the base example.

That gives me this new vector here, and then we un-rotate and we do the intervention. Remember the essence of DAS is that we're going to freeze the model parameters. This is an analysis method, not a method where we change the core underlying target model. But the thing that we do is learn a rotation matrix that essentially maximizes the interchange intervention accuracy that we get from doing this rotation and then un-rotation to create these new models.
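The rotate, intervene, un-rotate sequence can be sketched as a single function. This is an illustration under stated assumptions rather than the reference implementation: `das_intervention` is a hypothetical name, the three-dimensional vectors are random placeholders for X_1, X_2, X_3, and the orthogonal matrix comes from a QR decomposition of random numbers instead of being learned.

```python
import numpy as np


def das_intervention(rotation, base, sources):
    """Rotate, overwrite selected rotated coordinates from sources, un-rotate.

    `sources` maps a rotated-coordinate index to the source representation
    that supplies it; coordinates not listed keep the base value.
    """
    y = rotation @ base
    for index, source in sources.items():
        y[index] = (rotation @ source)[index]
    # For an orthogonal matrix the transpose is the inverse.
    return rotation.T @ y


rng = np.random.default_rng(0)
# Random orthogonal matrix as a placeholder for a learned R.
rotation, _ = np.linalg.qr(rng.normal(size=(3, 3)))
# Placeholder representations for the base and the two source examples.
base, source_left, source_right = rng.normal(size=(3, 3))

# Y_1 from the left source, Y_2 from the right source, Y_3 from the base.
mixed = das_intervention(rotation, base, {0: source_left, 1: source_right})
print(rotation @ mixed)
```

In training, the intervened vector would be fed back into the frozen network, and the mismatch between the network's output and the causal model's counterfactual output would supply the signal for updating R alone.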

This is a blend of IIT-like techniques, as well as classical causal abstraction. We keep the model frozen because we want to interpret it, but we learn that rotation matrix. That's the essence of DAS. Findings of DAS so far, these are pretty nuanced. In our DAS paper, we show that models learn truly hierarchical solutions to a hierarchical equality task.

This is in fact the one that's reviewed in our notebook for this course. But those solutions are easy to miss with standard causal abstraction because of this non-standard basis issue. Here's a more nuanced finding. In earlier causal abstraction work, we found that models learn theories of lexical entailment and negation that align with a high-level intuitive causal model.

But with DAS, we can uncover that they do that in a brittle way that actually preserves the identities of the lexical items rather than truly learning a general solution to the entailment issue. The third finding is from a separate paper. This is tremendously exciting because it shows that we can scale to levels that were impossible before. We no longer need to search for alignments, because now we essentially learn the alignment.

We scaled DAS to Alpaca and we discovered that Alpaca, a seven billion parameter model, implements an intuitive algorithm to solve a numerical reasoning task. I think this is just the start of the potential that we see for using DAS to understand our biggest and most performant and most interesting large language models.

Let me turn now to wrapping up with just some high-level conclusions here. First, I wanted to return to this diagram that I used to motivate analysis methods in general. We have these incredibly important goals for the field: identifying approved and disapproved uses, identifying and correcting pernicious social biases, and guaranteeing that models are safe in certain contexts.

I feel that we cannot offer guarantees about these issues unless we have analytic guarantees about the underlying models. For me, that implies a truly deep causal understanding of the mechanisms that shape their input-output behavior. For that reason, I think the analysis project in NLP is one of the most pressing projects for the field.

In that spirit, let's look ahead a little bit to the near future of explainability research for the field. First, as I said, we should be seeking causal explanations, but we also need human interpretable ones. If causality were the only requirement, we could just give low-level mechanistic, mathematical explanations of how the transformer worked and call that explainability research.

But that's at the wrong level for humans trying to offer guarantees about safety and trustworthiness. We need human interpretable explanations. We need to apply these methods to ever larger instruct-trained LLMs. Those are the most relevant artifacts for the current moment. I think we're starting to approach this goal with DAS.

I just mentioned how we can apply it to Alpaca. I think we could scale even further, but we really want to be unconstrained in terms of what we can explore, and that requires a lot more innovation in the space. Then finally, to return to the previous unit and our discussion of COGS and ReCOGS and compositionality, I think we're seeing increasing evidence that models are inducing a semantics, that is, a mapping from language into a network of concepts.

If they are doing that and if we can find strong evidence for that, it's tremendously eye-opening about what kinds of data-driven learning processes could lead a language technology to actually have induced a semantics from its experiences. That would in turn lead us down the road of having many more guarantees that their behavior would be systematic, which could be a basis for them being, again, trustworthy, safe, and secure, and all of those important goals for the field and for society.