
How the Residual Stream is (Not) Linear

I commonly hear about the linearity of the residual stream of Transformer language models. Linearity, it is argued, is powerful for interpretability, and gives validity to interpretability tools like logit lens and steering vectors. In this brief note, we ask, linear with respect to what? We argue that the residual stream is not linear with respect to anything that gives us significant leverage.

The additive nature of updates to the residual stream is still interesting, but not as powerful as linearity. Further, empirically, the assumption that additive changes correspond to generalizing behavior changes seems to hold in some cases, perhaps because of which functions are easy to learn with residual connections. However, I encourage researchers not to refer to the Transformer's architecture to assign interpretability/control value to methods that implicitly assume linearity, and to instead let those methods' empirical results stand alone.

If you're not familiar with these concepts, I encourage you to first read Elhage et al., 2021.

What this post is not about

This post is not about the linear representation hypothesis---roughly, that high-level concepts can empirically be intervened upon through additive changes. Nor is it about superposition or polysemanticity. We're just interested in clarity of thought around the Transformer architecture.

The Residual Stream

By the residual stream, we mean the sequence of hidden states of a token as each layer processes it. For example, the residual stream for the $i$-th token in the $\ell$-th layer is given by:

\begin{align} &h^{(0)}_i = E w_i\\ &\tilde{h}^{(\ell)}_i = h^{(\ell-1)}_i + \text{MLP}(h^{(\ell-1)}_i)\\ &h^{(\ell)}_i = \tilde{h}^{(\ell)}_i + \text{Attn}(\tilde{h}^{(\ell)}_{1:i})_i \end{align}
where $\tilde{h}^{(\ell)}_i$ is the residual stream after the MLP contribution from the $\ell$-th layer, $h^{(\ell)}_i$ is the residual stream after the attention contribution from the $\ell$-th layer, and $Ew_i$ is the embedding matrix indexed at word $w_i$. So, each sublayer of the Transformer block (both the MLP and the attention) adds a vector to the residual stream.
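To make the bookkeeping concrete, here is a minimal numerical sketch of this recurrence, with random weights, a uniform-average stand-in for attention, and the norms omitted for brevity; every name in it is made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_layers, vocab, seq_len = 16, 3, 100, 5

E = rng.normal(size=(vocab, d)) / np.sqrt(d)     # toy embedding matrix
tokens = rng.integers(0, vocab, size=seq_len)

def mlp(h, W1, W2):
    # toy MLP sublayer: W2 sigma(W1 h), norms omitted for brevity
    return W2 @ np.maximum(W1 @ h, 0.0)

def attn(prefix, Wv):
    # stand-in for attention: uniform average over value-projected prefix
    return Wv @ prefix.mean(axis=0)

params = [(rng.normal(size=(4 * d, d)) / d,
           rng.normal(size=(d, 4 * d)) / d,
           rng.normal(size=(d, d)) / d) for _ in range(n_layers)]

h = E[tokens]                                    # h^(0) for every position
for W1, W2, Wv in params:
    # MLP write: each position's stream gets its own additive update
    h_tilde = h + np.stack([mlp(x, W1, W2) for x in h])
    # Attention write: position i reads from positions 1..i of h_tilde
    h = h_tilde + np.stack([attn(h_tilde[: j + 1], Wv)
                            for j in range(seq_len)])

print(h.shape)                                   # (seq_len, d)
```

Note that every update is an addition to the stream; nothing else about the sublayers is constrained.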

What is the residual stream linear with respect to?

If I take a late-layer activation $h^{(\ell)}_i$, I can write it out as the embedding plus the sum of the contributions from all previous sublayers:
\begin{align} h^{(\ell)}_i = Ew_i + \sum_{j=1}^{\ell} \left( \text{MLP}(h^{(j-1)}_i) + \text{Attn}(\tilde{h}^{(j)}_{1:i})_i\right) \end{align}

This is a linear combination of the contributions from all previous sublayers, with all weights equal to $1$ (below, we drop the embedding term $Ew_i$ from the notation, since it carries through every equation unchanged and doesn't affect the argument). We can write this as a linear operation with a matrix $\mathcal{H}\in\mathbb{R}^{d\times 2\ell}$, formed by stacking all the sublayer contributions, and $\mathbf{1}^{2\ell}$, a $2\ell$-dimensional vector of ones. So, $\mathcal{H}$ looks like
\begin{align} \mathcal{H} = \begin{bmatrix} | & | & | & | & & | & | \\ \text{MLP}^{(1)} & \text{Attn}^{(1)}_i & \text{MLP}^{(2)} & \text{Attn}^{(2)}_i & \cdots & \text{MLP}^{(\ell)} & \text{Attn}^{(\ell)}_i \\ | & | & | & | & & | & | \end{bmatrix} \end{align}
and we have
\begin{align} h^{(\ell)}_i = \mathcal{H}\mathbf{1} \end{align}
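We can check this decomposition numerically. Below is a toy single-token network (MLP sublayers only, random weights, norms omitted; all names are hypothetical): stacking the embedding and each sublayer's contribution as the columns of $\mathcal{H}$ and multiplying by a vector of ones recovers the final activation exactly.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n_layers = 8, 3

def mlp(h, W1, W2):
    return W2 @ np.maximum(W1 @ h, 0.0)

weights = [(rng.normal(size=(4 * d, d)) / np.sqrt(d),
            rng.normal(size=(d, 4 * d)) / np.sqrt(4 * d))
           for _ in range(n_layers)]

h = rng.normal(size=d)         # stand-in for the embedding E w_i
columns = [h.copy()]           # embedding as the first column
for W1, W2 in weights:
    delta = mlp(h, W1, W2)     # this sublayer's additive contribution
    columns.append(delta)
    h = h + delta

H = np.stack(columns, axis=1)  # shape (d, n_layers + 1)
assert np.allclose(H @ np.ones(H.shape[1]), h)
```

This identity holds by construction; the question is what leverage it buys us.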
Ok, so we can write any layer as linear in the previous layers at that index. What leverage does this give us? Well, consider an additive intervention, common in interpretability with both steering vectors and sparse autoencoder edits:
\begin{align} h'^{(q)}_i \leftarrow h^{(q)}_i + v \end{align}
For a later layer $\ell>q$ in the network, does this mean that the way $h^{(\ell)}_i$ changes will be predictable somehow? As in:
\begin{align} {\color{#c0392b}{h'^{(\ell)}_i}} \stackrel{?}{=} \mathcal{H}\mathbf{1} + {\color{#c0392b}{v}} \end{align}
Or
\begin{align} {\color{#c0392b}{h'^{(\ell)}_i}} \stackrel{?}{=} \mathcal{H}\mathbf{1} + {\color{#c0392b}f({v})} \end{align}
for some simple $f$? Neither is the case. We no longer have $\mathcal{H}$, because all layers after $q$ have been modified in non-linear ways. This is due to (1) the non-linearity in the MLP, (2) the non-linearity in the layer or RMS norm, and (3) the attention mechanism drawing information from other indices---though mostly (1) and (2). We have
\begin{align} {\color{#c0392b}\mathcal{H}'} = \begin{bmatrix} | & | & | & | & & {\color{#c0392b}|} & {\color{#c0392b}|} & & {\color{#c0392b}|} & {\color{#c0392b}|} \\ \text{MLP}^{(1)} & \text{Attn}^{(1)}_i & \cdots & \text{MLP}^{(q)} & \text{Attn}^{(q)}_i & {\color{#c0392b}\text{MLP}'^{(q+1)}} & {\color{#c0392b}\text{Attn}'^{(q+1)}_i} & \cdots & {\color{#c0392b}\text{MLP}'^{(\ell)}} & {\color{#c0392b}\text{Attn}'^{(\ell)}_i} \\ | & | & | & | & & {\color{#c0392b}|} & {\color{#c0392b}|} & & {\color{#c0392b}|} & {\color{#c0392b}|} \end{bmatrix} \end{align}
Indeed, the way the sublayer contributions after $h^{(q)}_i$ change reflects their corresponding functions' non-linearity, and the attention sublayers even read differently from all the other tokens in the sequence! So, the residual stream is linear with respect to all sublayer contributions at that index, but because of the non-linear dependencies between the layers, this gives us no guarantees.
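A toy experiment makes this concrete. In the sketch below (MLP sublayers only, a single token, random weights; all names are made up), we add $v$ at layer $q$ and compare the resulting change at the last layer to $v$ itself; they differ, because the later sublayers respond non-linearly:

```python
import numpy as np

rng = np.random.default_rng(2)
d, n_layers, q = 8, 4, 1

def mlp(h, W1, W2):
    return W2 @ np.maximum(W1 @ h, 0.0)

weights = [(rng.normal(size=(4 * d, d)) / np.sqrt(d),
            rng.normal(size=(d, 4 * d)) / np.sqrt(4 * d))
           for _ in range(n_layers)]

def run(h0, v=None):
    h = h0.copy()
    for j, (W1, W2) in enumerate(weights):
        h = h + mlp(h, W1, W2)
        if v is not None and j == q:
            h = h + v                    # additive intervention at layer q
    return h

h0 = rng.normal(size=d)
v = rng.normal(size=d)
diff = run(h0, v) - run(h0)              # change observed at the last layer
print(np.allclose(diff, v))              # False in general: no "+v" guarantee
```

The downstream change contains $v$ plus the changes in every subsequent sublayer's output, and those changes are not a simple function of $v$.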

To really drive the point home, note that this has nothing to do with residual connections. A neural network with no residual connections is also linear with respect to all sublayer contributions at that index in the same way. To see this, imagine a feed-forward network with no residual connection. The only difference is that we replace the linear weights $\mathbf{1}$ with a vector $[0;\cdots;0;1]$ with exactly $2\ell-1$ zeros and a single one. That is, the weights on previous sublayers are zero, and the weight on the most recent sublayer is $1$.
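The same toy setup with the residual connections removed illustrates this (one MLP sublayer per layer for brevity; all names are made up): the final activation is $\mathcal{H}$ times a one-hot weight vector, which is just as "linear".

```python
import numpy as np

rng = np.random.default_rng(3)
d, n_layers = 8, 3

def mlp(h, W1, W2):
    return W2 @ np.maximum(W1 @ h, 0.0)

weights = [(rng.normal(size=(4 * d, d)) / np.sqrt(d),
            rng.normal(size=(d, 4 * d)) / np.sqrt(4 * d))
           for _ in range(n_layers)]

h = rng.normal(size=d)
columns = []
for W1, W2 in weights:
    h = mlp(h, W1, W2)        # no residual: each sublayer replaces the stream
    columns.append(h.copy())

H = np.stack(columns, axis=1)             # one column per sublayer output
w = np.zeros(n_layers)
w[-1] = 1.0                               # weights [0, ..., 0, 1]
assert np.allclose(H @ w, h)              # final activation = last column
```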

So, if this is the sense in which Transformer residual streams are deeply linear, then the activation sequences of all other neural networks are deeply linear in the same sense.

The residual stream is not linear with respect to the previous layer

The residual stream is not linear with respect to the previous layer, even if we allow all token indices to be considered as inputs to the linear function. This may be familiar, but it bears repeating. The self-attention layer does compute a linear combination of a linear transformation of the vectors in the previous layer. However, the non-linearities in the MLP and the RMS/layer norm mean that the residual stream is not linear with respect to the previous layer.
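This is easy to check on a single sublayer. The sketch below (random weights, a pre-norm MLP sublayer of my own construction, not any particular model) tests additivity and homogeneity directly; both fail:

```python
import numpy as np

rng = np.random.default_rng(4)
d = 8
W1 = rng.normal(size=(4 * d, d))
W2 = rng.normal(size=(d, 4 * d))

def rms_norm(x):
    return x / np.sqrt(np.mean(x ** 2) + 1e-6)

def layer(h):
    # one pre-norm MLP sublayer, residual connection included
    return h + W2 @ np.maximum(W1 @ rms_norm(h), 0.0)

x, y = rng.normal(size=d), rng.normal(size=d)
print(np.allclose(layer(x + y), layer(x) + layer(y)))  # False: not additive
print(np.allclose(layer(2 * x), 2 * layer(x)))         # False: not homogeneous
```

The homogeneity failure is especially stark: the RMS norm is scale-invariant, so doubling the input doesn't even double the MLP's contribution.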

Doesn't each sublayer perform a linear read?

I don't interpret it as such, even though it starts with a linear operation. The MLP has the following form:
\begin{align} \text{MLP}(x) = \text{Norm}(W_2 \sigma(W_1 x)) \end{align}
where $W_1$ and $W_2$ are learned weight matrices and $\sigma$ is a non-linear activation function. A function that starts with a linear transformation ($W_1$) is no more linear for it: MLPs can approximate rather complex functions with their high intermediate dimensionality. We can certainly make statements like "anything in the input outside the row span of $W_1$ will not affect the output of the MLP", but if $W_1$ has rank $d$, this is vacuous, since nothing in the input is outside that span.
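To see when such a statement is not vacuous, here is a toy sketch with a deliberately low-rank read matrix (everything here is a made-up illustration): input components outside the read subspace leave the MLP's output untouched, but that subspace only exists because we forced $W_1$ to be rank-deficient.

```python
import numpy as np

rng = np.random.default_rng(5)
d, d_mid = 8, 32

# Hypothetical low-rank "read": W1 only reads a 3-dimensional input subspace
U = rng.normal(size=(3, d))
W1 = rng.normal(size=(d_mid, 3)) @ U
W2 = rng.normal(size=(d, d_mid))

def mlp(x):
    return W2 @ np.maximum(W1 @ x, 0.0)

x = rng.normal(size=d)
v = rng.normal(size=d)
P = U.T @ np.linalg.solve(U @ U.T, U)   # projector onto the read subspace
v_null = v - P @ v                      # component W1 cannot read

assert np.allclose(W1 @ v_null, 0.0, atol=1e-10)
assert np.allclose(mlp(x + v_null), mlp(x))   # invisible to the MLP
```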

Doesn't each sublayer perform a linear write?

Again, I don't interpret it as such. Again we ask, a linear function of what? The Attn block feels like it should be a linear combination of the value-projected vectors of the previous layer (mixed through the output matrix), but this isn't the case, because of the layer/RMS norms. The MLP block feels like it should be a linear combination of the columns of $W_2$ (defined above), that is, the output vectors of the MLP weighted by the neurons in the intermediate state of the MLP, $\sigma(W_1 x)$. If it were, that would be mildly interesting, even though that combination is non-linearly dependent on the input vector. But it's not even that: because a layer/RMS norm is applied afterwards, the contribution isn't even linear in the intermediate state $\sigma(W_1 x)$.
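The sketch below (random weights, norm placement following the formula above; all names are made up) checks both claims: without the output norm the MLP's write is exactly a weighted combination of $W_2$'s columns, and with the norm, it isn't.

```python
import numpy as np

rng = np.random.default_rng(6)
d, d_mid = 8, 32
W1 = rng.normal(size=(d_mid, d))
W2 = rng.normal(size=(d, d_mid))

def rms_norm(x):
    return x / np.sqrt(np.mean(x ** 2) + 1e-6)

x = rng.normal(size=d)
a = np.maximum(W1 @ x, 0.0)           # intermediate neuron activations

write_plain = W2 @ a                  # exactly a combination of W2's columns
assert np.allclose(write_plain, sum(a[k] * W2[:, k] for k in range(d_mid)))

write_normed = rms_norm(W2 @ a)       # the post's form, with the output norm
print(np.allclose(write_normed, write_plain))   # False: the norm rescales it
```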

Instead, we're back to the fact that the sublayer adds a vector to the residual stream. And again we ask, what leverage do we gain from this?

Wait, so why do logit lens and steering vectors and SAEs work?

It depends on your definition of "work", but it does seem that interventions on very interesting high-level concepts can be approximated to some extent by additive interventions on the residual stream. This is a fascinating set of empirical results. However, there is no a priori, linearity-based reason to expect them. Perhaps these interventions work well; perhaps we like additive interventions for some separate aesthetic reason; either way, the methods must be evaluated on their own merits.

That is, the logit lens is not somehow a priori a valid representation of what word the model is trying/going to say at that point in the network; this would require linearity between the middle layers and the last layer of the network. As we've seen, there is no such linearity.
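For concreteness, here is what the logit-lens operation looks like on a toy random-weight model (all names hypothetical): the unembedding is simply applied to a middle-layer residual stream, and nothing in the architecture relates that readout to the final one.

```python
import numpy as np

rng = np.random.default_rng(7)
d, vocab, n_layers = 16, 50, 4

def mlp(h, W1, W2):
    return W2 @ np.maximum(W1 @ h, 0.0)

weights = [(rng.normal(size=(4 * d, d)) / np.sqrt(d),
            rng.normal(size=(d, 4 * d)) / np.sqrt(4 * d))
           for _ in range(n_layers)]
W_U = rng.normal(size=(vocab, d))          # made-up unembedding matrix

h = rng.normal(size=d)                     # made-up residual stream h^(0)
mid_logits = None
for j, (W1, W2) in enumerate(weights):
    h = h + mlp(h, W1, W2)
    if j == n_layers // 2 - 1:
        mid_logits = W_U @ h               # logit lens: decode a middle layer
final_logits = W_U @ h

# With random weights there is no reason these readouts should agree:
print(int(mid_logits.argmax()), int(final_logits.argmax()))
```

In trained models, the extent to which the two readouts agree is an empirical finding about learned representations, not a consequence of the architecture.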

In the end, it may come down to optimization pressure and gradient descent. The residual connections of a Transformer very likely make it easier to learn functions that leverage the additive flow of information, and thus make additive interventions particularly (sort of) effective. This would be great! But the utility of this (approximate) result can only be shown by methods that work, say, steering that outperforms prompting or re-running the RL process.

So what can we assume?

First, the predictions are log-linear in the last layer of the network. This is really cool, leading to intuitions about the softmax bottleneck and to methods for figuring out the model dimensionality of proprietary models.
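This log-linearity is easy to state precisely: differences of log-probabilities are exactly linear in the final residual stream. A minimal sketch, with a made-up unembedding acting directly on the final activation:

```python
import numpy as np

rng = np.random.default_rng(8)
d, vocab = 16, 50
W_U = rng.normal(size=(vocab, d))     # made-up unembedding
h = rng.normal(size=d)                # made-up final-layer residual stream

logits = W_U @ h
m = logits.max()
log_probs = logits - (np.log(np.exp(logits - m).sum()) + m)   # log softmax

# Log-probability *differences* are exactly linear in h:
a, b = 3, 7
assert np.allclose(log_probs[a] - log_probs[b], (W_U[a] - W_U[b]) @ h)
assert np.isclose(np.exp(log_probs).sum(), 1.0)
```

The softmax normalizer cancels in the difference, which is what makes the last layer genuinely log-linear in a way middle layers are not.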

Second, empirical intuitions about how additive changes to the residual stream seem to propagate nicely deeper into the network are still valid; they just shouldn't be taken for granted.

Let's build interpretability methods that we can prove---or show empirically---generalize better, and are more precise.

References