Tokenized SAEs

18 Jul 2024

Thomas Dooms, Daniel Wilhelm

Background

Sparse auto-encoders have become the main focus of mechinterp research; they provide a means to extract interpretable bases from intermediate stages of a model by locally reconstructing its activations. While this has been remarkably effective, it has some suboptimal side-effects; in this paper we focus on distribution-dependent features. We explain why these arise and propose a technique to separate some of them from the main reconstruction.

Motivation

The motivation for this work can be framed from multiple perspectives. Initially, it arose from spending quite a bit of time on Neuronpedia and seeing that the vast majority of features are token-based instead of context-based. From a training perspective, this makes sense due to the following two facts:

  • Token directions are generally the most important directions in the representation.
  • Token directions occur far more frequently than any specific context-based representation.

Regardless of the reason they exist, ideally, they wouldn’t clutter the (generally limited) learned features.

Method

We generally introduce the proposed method as a simple trick to remove these single-token features, which seems to work surprisingly well. The main idea is to add a lookup table to the decoder, indexed by the original token of the current activations. This table then takes care of the "base" reconstruction for each token. We denote the hidden activations that originate from a token $t$ as $x_t$ and the sparsity/activation function as $\sigma$.

$$a_t = \sigma(W_{enc}(x_t - b_{dec}))$$

$$\hat{x}_t = W_{dec}(a_t) + b_{dec} + \mathbf{W_{lookup}(t)}$$

There are some slight caveats to training this lookup table, which are described in the paper. Outside of this, it’s really just that simple.
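To make the two equations concrete, here is a minimal PyTorch sketch of the forward pass. The ReLU sparsity function, the module layout, and the use of nn.Embedding for the lookup table are illustrative assumptions on our part; the training caveats from the paper are not reproduced.

```python
import torch
import torch.nn as nn

class TokenizedSAE(nn.Module):
    """Minimal sketch of an SAE whose decoder adds a per-token lookup term."""

    def __init__(self, d_model: int, d_hidden: int, vocab_size: int):
        super().__init__()
        self.W_enc = nn.Linear(d_model, d_hidden, bias=False)
        self.W_dec = nn.Linear(d_hidden, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        # One "base" reconstruction direction per vocabulary token.
        self.W_lookup = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor, tokens: torch.Tensor):
        # a_t = sigma(W_enc (x_t - b_dec)); ReLU is an assumed choice of sigma.
        a = torch.relu(self.W_enc(x - self.b_dec))
        # x_hat_t = W_dec a_t + b_dec + W_lookup(t)
        x_hat = self.W_dec(a) + self.b_dec + self.W_lookup(tokens)
        return x_hat, a
```

Here `tokens` holds the token ids that produced the activations `x`, so the lookup term is aligned position-by-position with the reconstruction.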

Results

We show that adding a lookup table improves the final reconstructions by a significant margin on layer 8 of GPT-2.

Pareto curves of the normalized mean squared error (top) and added cross entropy (bottom) across SAE architectures. All SAEs were trained on layer 8 of GPT-2 small with about 300M tokens.

Beyond this, forcing the SAE to use directions we already know to be useful lets it learn much more quickly. We measure how much faster TSAEs reach the final reconstruction value of their baseline variant.
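For concreteness, one hypothetical way to compute this speedup from two per-step loss curves is sketched below; the paper's exact definition may differ.

```python
import numpy as np

def speedup(tsae_loss: np.ndarray, baseline_loss: np.ndarray) -> float:
    """Hypothetical metric: factor by which the TSAE needs fewer steps
    to first match the baseline's final reconstruction loss.
    Both inputs are per-step loss curves of equal length."""
    target = baseline_loss[-1]
    hits = np.nonzero(tsae_loss <= target)[0]
    if len(hits) == 0:
        return 1.0  # the TSAE never reached the baseline's final loss
    return len(baseline_loss) / (hits[0] + 1)
```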

Speedup of TSAE vs baseline on GPT-2 small.

This speedup means TSAEs reach high-fidelity reconstructions in only a few minutes across GPT-2 layers. We believe this will be very useful for quickly iterating on SAE suites.

A common intuition is that this lookup table would become less effective with model depth or complexity. However, our results show that Tokenized SAEs remain effective on Pythia 1.4B, even at layer 20. We again show the speedup.

Speedup of TSAE vs baseline on Pythia 1.4B.

Future Work

The general idea of incorporating inductive bias into SAEs seems interesting to pursue. Trivially, a more general (sparse) n-gram lookup table could be used, as sketched below. Furthermore, features from a preceding SAE could similarly be fed into some kind of lookup table.
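As one possible instantiation of the n-gram idea (a design sketch on our part, not the paper's), bigrams could be hashed into a fixed number of buckets so the table stays small and sparse:

```python
import torch
import torch.nn as nn

class HashedBigramLookup(nn.Module):
    """Hypothetical bigram extension of the token lookup table."""

    def __init__(self, d_model: int, n_buckets: int = 2**18):
        super().__init__()
        self.n_buckets = n_buckets
        self.table = nn.Embedding(n_buckets, d_model)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Pair each token with its predecessor (zero-padding the first position).
        prev = torch.roll(tokens, shifts=1, dims=-1)
        prev[..., 0] = 0
        # Simple multiplicative hash of the (previous, current) token pair.
        idx = (prev * 1_000_003 + tokens) % self.n_buckets
        return self.table(idx)
```

The hashing trick keeps the parameter count independent of the vocabulary size squared, at the cost of occasional bucket collisions.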