How the Residual Stream is (Not) Linear
I commonly hear about the linearity of the residual stream of Transformer language models. Linearity, it is argued, is powerful for interpretability, and gives validity to interpretability tools like logit lens and steering vectors. In this brief note, we ask, linear with respect to what? We argue that the residual stream is not linear with respect to anything that gives us significant leverage.
The additive nature of updates to the residual stream is still interesting, but not as powerful as linearity. Further, empirically, assuming that additive changes correspond to generalizing behavior change does seem to work in some cases, perhaps because of which functions are easy to learn with residual connections. However, I encourage researchers not to appeal to the Transformer's architecture to assign interpretability/control value to methods that implicitly assume linearity, and to instead let those methods' empirical results stand on their own.
If you're not familiar with these concepts, I encourage you to first read Elhage et al., 2021.
What this post is not about
This post is not about the linear representation hypothesis---roughly, that high-level concepts can empirically be intervened upon through additive changes. Nor is it about superposition or polysemanticity. We're just interested in clarity of thought around the Transformer architecture.
The Residual Stream
By the residual stream, we mean the sequence of hidden states of a token as each layer processes it. For example, in a pre-LN Transformer, the residual stream for the $i$-th token after the $\ell$-th layer is given by:

$$h^{(\ell)}_i = h^{(\ell-1)}_i + \mathrm{Attn}^{(\ell)}\!\left(h^{(\ell-1)}\right)_i + \mathrm{MLP}^{(\ell)}\!\left(h^{(\ell-1)}_i + \mathrm{Attn}^{(\ell)}\!\left(h^{(\ell-1)}\right)_i\right),$$

where each sublayer normalizes its input internally and adds its output to the stream.
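As a minimal numerical sketch of one such block---all sizes, weights, and helper names below are hypothetical toy stand-ins, with attention collapsed to a single matrix for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_ff = 8, 32  # hypothetical model width and MLP width

def rms_norm(x):
    # Simple RMS norm without learned scale, for illustration.
    return x / np.sqrt(np.mean(x**2) + 1e-6)

# Hypothetical weights for a single block. Real attention mixes across
# token positions; one matrix suffices to show the additive structure.
W_attn = rng.normal(size=(d, d)) / np.sqrt(d)
W1 = rng.normal(size=(d_ff, d)) / np.sqrt(d)
W2 = rng.normal(size=(d, d_ff)) / np.sqrt(d_ff)

def block(h):
    """One pre-LN block: each sublayer *adds* its output to the stream."""
    h = h + W_attn @ rms_norm(h)                   # attention write
    h = h + W2 @ np.maximum(0, W1 @ rms_norm(h))   # MLP write
    return h

h0 = rng.normal(size=d)
h1 = block(h0)
```

The point to notice is that both sublayers only ever *add* vectors to `h`; nothing about this structure is linear in `h` itself.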
What is the residual stream linear with respect to?
If I take a late-layer activation $h^{(\ell)}_i$, I can write it out as a sum of all contributions from all previous layers:

$$h^{(\ell)}_i = h^{(0)}_i + \sum_{k=1}^{\ell}\left(a^{(k)}_i + m^{(k)}_i\right),$$

where $a^{(k)}_i$ and $m^{(k)}_i$ are the attention and MLP sublayer outputs at layer $k$. This is linear in the sublayer contributions---trivially so, since the weights on all of them are $1$. To really drive the point home, note that this has nothing to do with residual connections: a neural network with no residual connections is also linear with respect to its sublayer contributions at that index in the same way. To see this, imagine a feed-forward network with no residual connection. The only difference is that we replace the linear weights $\mathbf{1}$ with a vector $[0;\cdots;0;1;\cdots]$ with exactly $2(\ell-1)d$ zeros and $d$ ones. That is, the weights on previous sublayers are zero, and the weights on the most recent sublayer are $1$.
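This unrolled-sum view can be checked numerically. A toy sketch, assuming the same hypothetical setup as before (random weights, attention collapsed to one matrix per layer):

```python
import numpy as np

rng = np.random.default_rng(1)
d, L = 8, 4  # hypothetical width and depth

def rms_norm(x):
    return x / np.sqrt(np.mean(x**2) + 1e-6)

# Hypothetical per-layer weights (attention as one matrix, for brevity).
attn_w = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(L)]
mlp_w = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(L)]

h0 = rng.normal(size=d)
h = h0.copy()
writes = []  # every sublayer's additive contribution, in order
for Wa, Wm in zip(attn_w, mlp_w):
    a = Wa @ rms_norm(h); h = h + a; writes.append(a)
    m = Wm @ rms_norm(h); h = h + m; writes.append(m)

# The late-layer state is exactly the embedding plus the sum of all
# sublayer writes: "linear" in the writes, with all weights equal to 1.
assert np.allclose(h, h0 + sum(writes))
```

The assertion holds by construction---which is precisely why this kind of linearity buys us so little.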
So, if this is the way Transformer residual connections are deeply linear, then all other neural network activation sequences are also deeply linear.
The residual stream is not linear with respect to the previous layer.
The residual stream is not linear with respect to the previous layer, even if we allow all token indices to be considered as inputs to the linear function. This may be familiar, but it bears repeating. The self-attention layer does compute a linear combination of linear transformations of the vectors in the previous layer. However, the non-linearities in the MLP and in the RMS/layer norms mean that the residual stream is not linear with respect to the previous layer.
Doesn't each sublayer perform a linear read?
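A two-line numerical check, using a toy MLP-only layer with hypothetical random weights (the attention softmax adds further non-linearity but isn't needed to break linearity):

```python
import numpy as np

rng = np.random.default_rng(2)
d = 8  # hypothetical width

def rms_norm(x):
    return x / np.sqrt(np.mean(x**2) + 1e-6)

W1 = rng.normal(size=(4 * d, d)) / np.sqrt(d)
W2 = rng.normal(size=(d, 4 * d)) / np.sqrt(4 * d)

def layer(h):
    # One residual sublayer: MLP with a pre-norm, added to the stream.
    return h + W2 @ np.maximum(0, W1 @ rms_norm(h))

a, b = rng.normal(size=d), rng.normal(size=d)
# A linear map would satisfy f(a + b) = f(a) + f(b). This one does not.
print(np.allclose(layer(a + b), layer(a) + layer(b)))  # False
```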
I don't interpret it as such, even though it starts with a linear operation. Even though the MLP has the following form:

$$\mathrm{MLP}(x) = W_2\,\sigma(W_1\,\mathrm{LN}(x)),$$

the matrix $W_1$ reads not the residual stream itself but its layer/RMS-normalized version, so the "read" is not a linear function of the residual stream.
Doesn't each sublayer perform a linear write?
Again, I don't interpret it as such. Again we ask, a linear function of what? The Attn block feels like it should be a linear combination of the value-projected vectors of the previous layer (mixed through the output matrix), but this isn't the case because of the layer/RMS norms. The MLP block feels like it should be a linear combination of the columns of $W_2$ (defined above), that is, the output vectors of the MLP weighted by the neurons in the intermediate state of the MLP, $\sigma(W_1 x)$. If it were, that would be mildly interesting, even though that combination is non-linearly dependent on the input vector. However, it's not even that, again because of the layer/RMS norm: the neuron activations are computed from the normalized input, not from the residual stream itself. And in architectures that also apply a layer/RMS norm to the sublayer output, the contribution isn't even linear in the intermediate state $\sigma(W_1 x)$. Instead, we're back to the fact that the sublayer adds a vector to the residual stream. And again we ask, what leverage do we gain from this?
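The gap between the "linear write" story and what the block actually adds can be seen directly. A toy sketch with hypothetical random weights, using a mean-subtracting layer norm:

```python
import numpy as np

rng = np.random.default_rng(3)
d, d_ff = 8, 32  # hypothetical widths

def layer_norm(x):
    # Simplified layer norm without learned scale/bias.
    return (x - x.mean()) / x.std()

W1 = rng.normal(size=(d_ff, d)) / np.sqrt(d)
W2 = rng.normal(size=(d, d_ff)) / np.sqrt(d_ff)

x = rng.normal(size=d)  # current residual-stream vector

# The "linear write" story: W2's columns weighted by neuron activations
# computed directly from the residual stream x.
story = W2 @ np.maximum(0, W1 @ x)

# What the block actually adds: the neurons see the *normalized* input,
# so the weights on W2's columns depend non-linearly on x.
actual = W2 @ np.maximum(0, W1 @ layer_norm(x))

print(np.allclose(story, actual))  # False
```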
Wait, so why do logit lens and steering vectors and SAEs work?
It depends on your definition of "work", but indeed it seems that intervention on very interesting high-level concepts can be approximated to some extent by additive interventions on the residual stream. This is a fascinating set of empirical results. However, there is no a priori linearity-related reason to prefer these methods. Perhaps these interventions work well; perhaps we like additive interventions for some separate aesthetic reason; but the methods must be evaluated on their own merits. That is, the logit lens is not somehow a priori a valid representation of what word the model is trying/going to say at that point in the network; this would require linearity between the last layer of the network and the middle layers. As we've seen, there's no such linearity.
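To make the object under discussion concrete, here is a sketch of the logit lens itself---all sizes and the unembedding matrix `W_U` are hypothetical stand-ins:

```python
import numpy as np

rng = np.random.default_rng(4)
d, vocab = 8, 50  # hypothetical width and vocabulary size

def rms_norm(x):
    return x / np.sqrt(np.mean(x**2) + 1e-6)

W_U = rng.normal(size=(vocab, d))  # hypothetical unembedding matrix

def logit_lens(h_mid):
    """Push an intermediate residual-stream state straight through the
    final norm and unembedding, skipping all remaining layers. Nothing
    about the architecture guarantees this approximates the true logits;
    its validity is an empirical question."""
    return W_U @ rms_norm(h_mid)

h_mid = rng.normal(size=d)
lens_logits = logit_lens(h_mid)  # an empirical probe, not a linear-algebra fact
```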
In the end, it may be because of optimization pressures and gradient descent. The residual connections of a Transformer very likely make it easier to learn functions that leverage the additive flow of information, and thus make additive interventions particularly (sort of) effective. This would be great! But the utility of this (approximate) result can only be shown through methods that work---say, steering that works better than prompting or re-doing the RL process.
So what can we assume?
First, the predictions are log-linear in the last layer of the network. This is really cool: it leads to intuitions about the softmax bottleneck, and to methods for estimating the hidden dimensionality of proprietary models.
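This one property is easy to demonstrate. A toy sketch, assuming a hypothetical unembedding matrix and pretending the rows of `H` are final-layer (post-norm) states for many inputs:

```python
import numpy as np

rng = np.random.default_rng(5)
d, vocab, n = 8, 50, 200  # hypothetical sizes

W_U = rng.normal(size=(vocab, d))  # hypothetical unembedding matrix

# Pretend these are final-layer (already-normalized) states for n inputs.
H = rng.normal(size=(n, d))
logits = H @ W_U.T  # log-linear in the last-layer state

# Every logit vector lies in W_U's d-dimensional column space, so the
# logit matrix has rank at most d: the softmax bottleneck, and the basis
# of hidden-size estimates for API-only models.
print(np.linalg.matrix_rank(logits))  # 8
```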
Second, empirical intuitions about how additive changes to the residual stream seem to propagate nicely deeper into the network are still valid; they just shouldn't be taken for granted.
Let's build interpretability methods that we can prove---or show empirically---generalize better, and are more precise.
References
- A Mathematical Framework for Transformer Circuits — Nelson Elhage, Neel Nanda, Catherine Olsson, et al., Anthropic, December 2021
- What is a Linear Representation? What is a Multidimensional Feature? — Chris Olah, Anthropic Circuits Updates, July 2024
- Non-Linear Feature Representations — Liv Gorton