AI Interpretability and Safety: How Features Could Shape the Future of LLMs

Nov 22, 2024

AI Interpretability & Safety with Sparse Autoencoders in LLMs

I remember the first technical talk I gave; it involved training a toy Convolutional Neural Network (CNN) to determine the right moment to press the UP arrow key and make the Chrome T-Rex jump in the browser. You show the CNN several images from actual gameplay and hope that it learns to press the UP button when it sees an obstacle. The simple model worked quite well, and I was trying to understand which pixels in the image caused the neurons to activate and therefore drove the model to generate the UP key press.

While reading Anthropic's interpretability research, I could sense that they had similar intuitions but scaled them to a whole other level and were able to derive some cool insights! In their work, they try to determine whether they can build intuitions for what features an LLM learns and whether controlling those features has an impact on the generated output.

Key takeaways from my reading:

  • Activations of individual neurons were not helpful on their own; instead, the researchers work with "features" - linear combinations of neurons that capture interpretable patterns.

  • They decompose the activations of a middle layer of their Sonnet model using a Sparse Autoencoder (SAE). This helps to "explode" the tightly packed information within the neurons of that layer into a much larger set of more interpretable features. It's like unfolding a complex origami to reveal its underlying structure. (A minimal sketch of such an SAE follows this list.)

  • In order to test whether these features had a causal impact, they manually adjusted (clamped up or down) the activation of a particular SAE feature and then fed the modified reconstruction back into the remaining layers of the LLM to complete the forward pass (sketched in the clamping example below).

  • By observing changes in the generated tokens, the researchers could determine whether the identified feature exists, matches its interpretation, and has the expected impact on the output.

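To make the decomposition concrete, here is a minimal sketch of a sparse autoencoder over a layer's activations, written in PyTorch. The sizes (512 neurons expanded into 4096 features), the ReLU encoder and the L1 sparsity penalty are illustrative assumptions on my part, not Anthropic's actual training setup.

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal SAE sketch: expands d_model neuron activations into a larger
    set of sparse, (hopefully) interpretable feature activations."""

    def __init__(self, d_model: int = 512, n_features: int = 4096):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_features)  # neurons -> features
        self.decoder = nn.Linear(n_features, d_model)  # features -> neurons

    def forward(self, activations: torch.Tensor):
        features = torch.relu(self.encoder(activations))  # sparse feature activations
        reconstruction = self.decoder(features)           # approximation of the input
        return features, reconstruction

# Training objective (sketch): reconstruct the layer's activations while
# keeping the feature activations sparse via an L1 penalty.
sae = SparseAutoencoder()
activations = torch.randn(8, 512)  # stand-in for middle-layer activations
features, reconstruction = sae(activations)
loss = ((reconstruction - activations) ** 2).mean() + 1e-3 * features.abs().mean()
```

Each feature corresponds to a direction in activation space (a column of the decoder), which is what "a linear combination of neurons" means in practice.
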
An example of such an identified feature is the Code Error feature, which activates when incorrect code snippets (across multiple programming languages) are passed through the LLM. The researchers noticed that when this feature is clamped up by a factor of three, the model completes the response with an error message even though the code itself is correct and would normally have generated the correct output.
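
Continuing the sketch above, this is roughly what that clamping step could look like: encode the layer's activations, scale a single feature (here by a factor of three, as in the code-error example), decode, and hand the modified activations to the remaining layers. The feature index and the `run_remaining_layers` call are hypothetical placeholders, not real APIs.

```python
@torch.no_grad()
def clamp_feature(sae: SparseAutoencoder,
                  activations: torch.Tensor,
                  feature_idx: int,
                  scale: float = 3.0) -> torch.Tensor:
    """Scale one SAE feature and return the modified (reconstructed) activations."""
    features, _ = sae(activations)
    features[:, feature_idx] *= scale  # clamp the chosen feature up
    return sae.decoder(features)       # map back to neuron space

# Hypothetical usage: feature 123 stands in for the "code error" feature.
modified = clamp_feature(sae, activations, feature_idx=123, scale=3.0)
# output = run_remaining_layers(modified)  # placeholder for the rest of the LLM's forward pass
```

The key point is that the intervention happens in feature space rather than on raw neurons, which is what makes the effect on the generated tokens interpretable.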

How is this useful?

Knowing that such interpretable features can be identified sheds more light on the "black box". More importantly, it provides a way to control the behavior of the LLM by modifying these features. This is particularly relevant for safety-related behaviors (e.g. an LLM's ability to produce bio-hazard recipes). Currently, such behaviors are constrained by post-training the LLM (RLHF, red-teaming), but in the future it might be possible to control them directly via these features.