A Method to Find Bilingual Features in Sparse Autoencoders

A systematic, data-driven process to find bilingual features inside GemmaScope sparse autoencoder models.
en
NLP
Mechanistic Interpretability
SAE
Multilinguality
Research Style Blog
Author
Affiliation

Diego Andrés Gómez Polo

Rappi

Published

September 29, 2024

Abstract

This blog post presents a systematic, data-driven process to generate a list of candidate bilingual features from a GemmaScope sparse autoencoder. We define a bilingual interpretability score for each feature, which depends on a dataset of equivalent English-Spanish sentences. We then rank the features based on this score and analyze them. Finally, we discuss the potential for extending this methodology to include more than two languages.

Reproducibility

To reproduce all the results, feel free to use this Colab Notebook. Be aware, however, that the part of the code that gathers the activations needs around 24-25 GB of CPU RAM, or a similar amount of GPU VRAM; the Colab free tier does not provide this much. You can still run the analysis part of the code using this dataset, which works on almost any relatively modern computer.

Introduction

Sparse autoencoders (SAEs) trained on the attention heads and residual streams of large language models have shown great promise at producing seemingly interpretable features (Cunningham et al., 2023). Features gathered from SAEs can be used to understand the inner workings of large language models and even to steer their behaviour in a desired direction (Templeton et al., 2024).

It is not uncommon to find that some of the features learned by SAEs are multilingual. This is particularly interesting because it suggests that the model has learned to represent and reason about concepts in an abstract way, independent of the language they are written in. The multilinguality of features can be viewed as evidence for the universality of features hypothesis, which states that learned representations are universal and can form across different models and tasks. This is one of the main speculative claims of the mechanistic interpretability agenda (Olah et al., 2020).

Figure 1: Rough Illustration of a Hooked SAE

But how can we find these multilingual features in a SAE?

Much of the recent work on SAEs and mechanistic interpretability has focused on scaling up the models to make them more powerful (Templeton et al., 2024) (Gao et al., 2024), finding techniques to make the models better at reconstructing the input (Rajamanoharan, Lieberum, et al., 2024), or using the learned features to find interesting circuits in the model (Wang et al., 2022). Many of these endeavours end up finding some multilingual features, but they are not the main focus of the work, nor are they systematically searched for.

In this work, we present a systematic, data-driven process to generate a list of candidate bilingual features from a GemmaScope SAE. We define a bilingual interpretability score for each feature, which depends on a dataset of equivalent English-Spanish sentences. We then rank the features based on this score and analyze them. Finally, we discuss the potential for extending this methodology to include more than two languages.

Methodology

The driving intuition behind our methodology is that, despite differences in tokenization, word order, general linguistic structure, and even the distribution of feature logits across languages, a feature can only be bilingual if it circumvents these differences and is activated by the same or similar sentences in both languages.

Such a condition may not be sufficient, but as we will see, it is a good starting point for finding bilingual features in a SAE.

In this section, we will describe the specific methodology that arises from this intuition, which consists of three main steps:

  1. Data Collection: We gather a dataset of equivalent English-Spanish sentences.
  2. Feature Extraction: We extract the features from the SAE for each sentence in the dataset.
  3. Bilingual Interpretability Score: We define a score that measures how similar the activations of a feature are across languages (a rough sketch of the idea follows this list).
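
As a rough, illustrative sketch of what such a score could look like (the function name and thresholding below are hypothetical, not the exact definition used later in the post), one can measure, per feature, how often its activation pattern agrees between the two languages:

import torch


def bilingual_agreement_sketch(
    acts_en: torch.Tensor, acts_es: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    """Illustrative per-feature agreement score (hypothetical, not the final definition).

    acts_en, acts_es: SAE activations of shape (samples, tokens, features) for the
    English and Spanish version of each sentence pair.
    Returns a (features,) tensor with the fraction of sentence pairs on which each
    feature behaves the same way in both languages (fires in both, or in neither).
    """
    fires_en = acts_en.max(dim=1).values > threshold  # (samples, features)
    fires_es = acts_es.max(dim=1).values > threshold  # (samples, features)
    return (fires_en == fires_es).float().mean(dim=0)  # (features,)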

Basic Setup

For our experiments, we used Gemma 2-2B as our language model, in its base pretrained version without any instruction tuning (see Figure 1). We focused our experiments on a single SAE from the GemmaScope collection of open-source SAEs (Lieberum et al., 2024), specifically the one with id gemma-scope-2b-pt-res-canonical/layer_20/width_16k/canonical. This SAE has 16k features and is trained on the residual stream of the 20th layer of the model. It is the smallest SAE available for this particular hook point, and the choice of size was made purely for computational reasons.

The 2B version of the Gemma 2 models has 26 layers (Team et al., 2024), so a SAE trained on the residual stream of the 20th layer is expected to have learned more abstract features than one trained on earlier layers. Moreover, we decided to use the residual stream instead of the attention heads because it is an information bottleneck to which all earlier attention heads and MLPs write and from which all later ones read, so one should expect the features learned at this specific point to be more abstract and general than those inside the attention mechanism (Elhage et al., 2021).

For our bilingual dataset, we used a small sample of the OPUS Books dataset (Tiedemann, 2012), with equivalent English-Spanish sentences.
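
For reference, here is a minimal sketch of how such a sample can be drawn with the Hugging Face datasets library; the opus_books dataset name and en-es config are the standard Hub identifiers, but the exact sampling and preprocessing used in this work may differ:

from datasets import load_dataset

# English-Spanish portion of OPUS Books on the Hugging Face Hub
# ("opus_books" / "en-es" are the standard Hub identifiers; the exact
# sample used in this work may differ).
books = load_dataset("opus_books", "en-es", split="train")

# Draw a small random sample of aligned sentence pairs.
sample = books.shuffle(seed=42).select(range(500))
pairs = [(row["translation"]["en"], row["translation"]["es"]) for row in sample]
print(pairs[0])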

Feature Extraction

To extract the features from the SAE, we used the HookedSAETransformer class from the SAELens library (Bloom and Chanin, 2024). This class allows us to hook our SAE to a given language model, and cache the activations of the SAE for a given set of inputs.

We ran the HookedSAETransformer on the English and Spanish sample pairs, and stored the activations using the datasets library from Hugging Face. This data is publicly available on the Hugging Face Hub under the name diegomezp/gemmascope_bilingual_activations. It contains not only the activations of the SAE for the sample pairs, but also the token ids of each sentence.
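
As a minimal sketch of how such activations can be gathered with SAELens (the release/sae_id split of the identifier above, the cached key name, and the return signature of SAE.from_pretrained follow common SAELens conventions and may need adjusting for other library versions):

import torch
from sae_lens import SAE, HookedSAETransformer

# Base (non instruction-tuned) Gemma 2-2B, wrapped so SAEs can be attached.
model = HookedSAETransformer.from_pretrained("gemma-2-2b")

# The canonical 16k-width SAE for the layer-20 residual stream.
sae, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
)


def get_sae_activations(sentence: str) -> torch.Tensor:
    # Run the model with the SAE hooked in and cache its feature activations.
    _, cache = model.run_with_cache_with_saes(sentence, saes=[sae])
    # Post-activation SAE features, shape (1, tokens, 16384).
    return cache[f"{sae.cfg.hook_name}.hook_sae_acts_post"]


acts_en = get_sae_activations("The cat sat on the mat.")
acts_es = get_sae_activations("El gato se sentó en la alfombra.")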

Show source
import os
from huggingface_hub import login
from datasets import load_dataset
from dotenv import load_dotenv
import torch
import plotly.graph_objects as go
import numpy as np
import warnings

warnings.filterwarnings("ignore")

# load the environment variables
load_dotenv(override=True)

# login to hugging face
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token, add_to_git_credential=True)

# download the dataset
sample_ds = load_dataset("diegomezp/gemmascope_bilingual_activations").with_format(
    "torch"
)
sample_ds = sample_ds["train"]

activation_tensor = torch.nested.nested_tensor(
    sample_ds["sae_features"]
).to_padded_tensor(0.0)


def get_single_lang_statistics(activation_tensor: torch.Tensor) -> dict:
    """
    Input:
      activation_tensor (torch.tensor float32): Tensor of size (samples, tokens, features)

    Output:
      (dict) : {
          "mean": {
            "value": float,
            "series": tensor size(|features|)
          },
          "q_0.05": # same structure as mean,
          "q_0.25": # same,
          "q_0.50": # same,
          "q_0.75": # same,
          "q_0.95": # same,
        }
    """
    s, t, f = activation_tensor.size()
    # Get quantiles only for those logits > 0
    activation_logits = activation_tensor[activation_tensor > 0]
    mean_act = activation_logits.mean()
    quantiles = [0.05, 0.25, 0.5, 0.75, 0.95]
    quantiles_values = torch.quantile(activation_logits, torch.tensor(quantiles))

    thresholds = {"mean": mean_act}
    thresholds.update({f"q_{q}": v for q, v in zip(quantiles, quantiles_values)})
    response = dict()
    max_activations = activation_tensor.max(dim=1).values  # size (s, f)

    for name, threshold in thresholds.items():
        response[name] = dict(value=threshold.item())
        final_activations = (
            (max_activations > threshold).to(float).mean(dim=0)
        )  # size f
        response[name]["series"] = final_activations.sort(descending=True).values
    return response


def get_activation_statistics(activation_tensor: torch.Tensor) -> dict:
    """
    Both globally and for each language we compute one series of size |features|
    per activation threshold (the mean and several quantiles of the positive
    activation logits). Each series gives, for every feature, the percentage of
    samples that had at least one activation above the threshold, sorted in
    descending order:

    Input:
      activation_tensor (torch.tensor): Tensor of size (samples, languages, tokens, features)

    Response:
      (dict) : {
        "stats: {
            "global": {
              "mean": {
                "value": float,
                "series": tensor size(|features|)
              },
              "q_0.05": # same structure as mean,
              "q_0.25": # same,
              "q_0.50": # same,
              "q_0.75": # same,
              "q_0.95": # same,
            },
            "lang_0": # Same structure as before,
            ...
            "lang_n": # Same as before
          }
        }
    """
    assert len(activation_tensor.size()) == 4, (
        "activation_tensor must have 4 dims: (samples, languages, tokens, features)"
    )
    s, l, t, f = activation_tensor.size()
    response = dict()
    response["stats"] = dict()
    response["stats"]["global"] = get_single_lang_statistics(
        activation_tensor.reshape(-1, t, f)
    )

    for idx in range(l):
        response["stats"][f"lang_{idx}"] = get_single_lang_statistics(
            activation_tensor[:, idx, :, :]
        )

    return response


activation_stats = get_activation_statistics(
    activation_tensor[:, :, 1:, :]
)  # Ignoring BOS token


# Determine the number of groups and tensors
num_groups = len(activation_stats["stats"])
num_tensors = len(activation_stats["stats"]["global"])

titles = {
    "global": "Percentage of samples that each feature activated (ordered)",
    "lang_0": "Percentage of activated Spanish samples (ordered)",
    "lang_1": "Percentage of activated English samples (ordered)",
}


def print_activation_stats(group_name, group_data):
    fig = go.Figure()
    for series_name, series_data in group_data.items():
        fig.add_trace(
            go.Scatter(
                x=np.arange(series_data["series"].size(0)),  # X-axis: indices of tensor
                y=series_data["series"].numpy(),  # Y-axis: tensor values
                mode="lines",  # Line plot
                name=f"{series_name}={series_data['value']:.2f}",  # Name for the legend
                showlegend=True,
            )
        )

    # Update layout
    fig.update_layout(
        title=titles[group_name],
        xaxis_title="Ordered SAE Features",
        yaxis_title="Percentage of Activating Samples",
        legend_title="Activation Threshold",
        template="plotly_white",
        xaxis_type="log",
    )

    fig.show()


# print_activation_stats("global", activation_stats["stats"]["global"])
Show source
print_activation_stats("lang_0", activation_stats["stats"]["lang_0"])