A Method to find Bilingual Features in Sparse autoencoders

A systematic, data-driven process to find bilingual features inside GemmaScope sparse autoencoder models.
en
NLP
Mechanistic Interpretability
SAE
Multilinguality
Research Style Blog
Author: Diego Andrés Gómez Polo
Affiliation: Rappi
Published: September 29, 2024

Abstract

This blog post presents a systematic, data-driven process to generate a list of candidate bilingual features from a GemmaScope Sparse autoencoder. We define a bilingual interpretability score for each feature, which is dependent on a dataset of equivalent English-Spanish sentences. We then rank the features based on this score and analyze them. Finally, we discuss the potential for extending this methodology to include more than 2 languages.

Reproducibility

To reproduce all the results, feel free to use this Colab Notebook. Be aware, however, that running the part of the code that gathers the activations requires around 24-25 GB of CPU RAM, or a similar amount of VRAM on GPU, which the Colab free tier does not provide. You can still run the analysis part of the code with this dataset; that part will run on almost any relatively modern computer.

Introduction

Sparse autoencoders (SAEs) trained on the attention heads and residual streams of large language models have shown great promise at producing seemingly interpretable features (Cunningham et al., 2023). Features gathered from SAEs can be used to understand the inner workings of large language models and even to steer their behaviour in a desired direction (Templeton et al., 2024).

It is not uncommon to find that some of the features learned by SAEs are multilingual. This is particularly interesting because it suggests that the model has learned to represent and reason through concepts in an abstract way that is independent of the language they are written in. The multilinguality of features can be viewed as evidence for the universality of features hypothesis, which states that learned representations are universal and can form across models and tasks. This is one of the main speculative claims of the mechanistic interpretability agenda (Olah et al., 2020).

Figure 1: Rough Illustration of a Hooked SAE

But how can we find these multilingual features in an SAE?

Much of the recent work on SAEs and mechanistic interpretability has been about either scaling up the models to make them more powerful (Templeton et al., 2024; Gao et al., 2024), finding techniques to make the models better at reconstructing the input (Rajamanoharan, Lieberum, et al., 2024), or using the learned features to find interesting circuits in the model (Wang et al., 2022). Many such endeavours end up finding some multilingual features, but these are neither the main focus of the work nor systematically searched for.

In this work, we present a systematic, data-driven process to generate a list of candidate bilingual features from a GemmaScope SAE. We define a bilingual interpretability score for each feature, which is dependent on a dataset of equivalent English-Spanish sentences. We then rank the features based on this score and analyze them. Finally, we discuss the potential for extending this methodology to include more than two languages.

Methodology

The driving intuition behind our methodology is that, in spite of changes in tokenization, word order, general linguistic structure, and even the distribution of feature logits across languages, a bilingual feature must overcome these differences and be activated by the same or similar sentences in both languages.

Such a condition may not be sufficient, but as we will see, it is a good starting point for finding bilingual features in an SAE.

In this section, we will describe the specific methodology that arises from this intuition, which consists of three main steps:

  1. Data Collection: We gather a dataset of equivalent English-Spanish sentences.
  2. Feature Extraction: We extract the features from the SAE for each sentence in the dataset.
  3. Bilingual Interpretability Score: We define a score that measures how similar the activations of a feature are across languages.

Basic Setup

For our experiments, we used Gemma 2 2B as our language model, in its base pretrained version without any instruction tuning (see Figure 1). We focused our experiments on a single SAE from the GemmaScope collection of open-source SAEs (Lieberum et al., 2024), specifically, the one with id gemma-scope-2b-pt-res-canonical/layer_20/width_16k/canonical. This SAE has 16k features and is trained on the residual stream of the 20th layer of the model. It is the smallest SAE available for this particular hook point, and its size was chosen purely for computational reasons.

The 2B version of the Gemma 2 models has 26 layers (Team et al., 2024), so an SAE trained on the residual stream at layer 20 is expected to have learned more abstract features than one trained on earlier layers. Moreover, we decided to use the residual stream instead of the attention heads because it is an information bottleneck that all layers write to, not only the attention heads immediately before the hook point, so one should expect the features learned at this specific point to be more abstract and general than those inside the attention mechanism (Elhage et al., 2021).

For our bilingual dataset, we used a small sample of the OPUS Books dataset (Tiedemann, 2012), with equivalent English-Spanish sentences.

Feature Extraction

To extract the features from the SAE, we used the HookedSAETransformer class from the SAELens library (Bloom and Chanin, 2024). This class allows us to hook our SAE to a given language model, and cache the activations of the SAE for a given set of inputs.

We ran the HookedSAETransformer on the English and Spanish sample pairs, and stored the activations using the datasets library from Hugging Face. This data is publicly available on the Hugging Face Hub under the name diegomezp/gemmascope_bilingual_activations. It contains not only the activations of the SAE for the sample pairs, but also the token ids of each sentence.
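The extraction step can be sketched as follows. This is a minimal illustration rather than the notebook's actual code: the helper names `sae_feature_activations` and `encode_pair` are hypothetical, the model and SAE are assumed to be already loaded, and the `run_with_cache_with_saes` call and the `hook_sae_acts_post` cache key reflect the SAELens API as we understand it around the time of writing, so they may differ in other versions.

```python
def sae_feature_activations(model, sae, text):
    """Return (tokens, SAE feature activations) for one sentence.

    Assumptions (not taken from the original notebook):
      * `model` is a SAELens HookedSAETransformer,
      * `sae` is a SAELens SAE exposing `cfg.hook_name`,
      * the cache key layout is "<hook_name>.hook_sae_acts_post".
    """
    # Shape (1, n_tokens); a BOS token is prepended by default.
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
    # Shape (1, n_tokens, d_sae): one activation per token and feature.
    acts = cache[f"{sae.cfg.hook_name}.hook_sae_acts_post"]
    return tokens[0], acts[0]


def encode_pair(model, sae, pair):
    """Encode one {"en": ..., "es": ...} translation pair, one language at a time."""
    return {
        lang: sae_feature_activations(model, sae, pair[lang])
        for lang in ("es", "en")
    }
```

Encoding each language separately keeps the two sentences from attending to each other, so any agreement between their activations comes from the content of the sentences alone.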

import os
from huggingface_hub import login
from datasets import load_dataset
from dotenv import load_dotenv
import torch
import plotly.graph_objects as go
import numpy as np
import warnings

warnings.filterwarnings("ignore")

# load the environment variables
load_dotenv(override=True)

# login to hugging face
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token, add_to_git_credential=True)

# download the dataset
sample_ds = load_dataset("diegomezp/gemmascope_bilingual_activations").with_format(
    "torch"
)
sample_ds = sample_ds["train"]

activation_tensor = torch.nested.nested_tensor(
    sample_ds["sae_features"]
).to_padded_tensor(0.0)


def get_single_lang_statistics(activation_tensor: torch.Tensor) -> dict:
    """
    Input:
      activation_tensor (torch.tensor float32): Tensor of size (samples, tokens, features)

    Output:
      (dict) : {
          "mean": {
            "value": float,
            "series": tensor size(|features|)
          },
          "q_0.05": # same structure as mean,
          "q_0.25": # same,
          "q_0.5": # same,
          "q_0.75": # same,
          "q_0.95": # same,
        }
    """
    s, t, f = activation_tensor.size()
    # Get quantiles only for those logits > 0
    activation_logits = activation_tensor[activation_tensor > 0]
    mean_act = activation_logits.mean()
    quantiles = [0.05, 0.25, 0.5, 0.75, 0.95]
    quantiles_values = torch.quantile(activation_logits, torch.tensor(quantiles))

    thresholds = {"mean": mean_act}
    thresholds.update({f"q_{q}": v for q, v in zip(quantiles, quantiles_values)})
    response = dict()
    max_activations = activation_tensor.max(dim=1).values  # size (s, f)

    for name, threshold in thresholds.items():
        response[name] = dict(value=threshold.item())
        final_activations = (
            (max_activations > threshold).to(float).mean(dim=0)
        )  # size f
        response[name]["series"] = final_activations.sort(descending=True).values
    return response


def get_activation_statistics(activation_tensor: torch.Tensor) -> dict:
    """
    Both globally and for each language we compute six series of size
    |features| (one per threshold: the mean and five quantiles). Each series
    gives the fraction of samples with at least one activation of a given
    feature above the threshold. The series are sorted in descending order,
    so positions in a series do not correspond to feature ids.

    Input:
      activation_tensor (torch.tensor): Tensor of size (samples, languages, tokens, features)

    Response:
      (dict) : {
        "stats": {
            "global": {
              "mean": {
                "value": float,
                "series": tensor size(|features|)
              },
              "q_0.05": # same structure as mean,
              "q_0.25": # same,
              "q_0.5": # same,
              "q_0.75": # same,
              "q_0.95": # same,
            },
            "lang_0": # Same structure as before,
            ...
            "lang_n": # Same as before
          }
        }
    """
    assert len(activation_tensor.size()) == 4, (
        "ActivationTensor must have 4 dims "
        "(samples, languages, tokens, features)"
    )
    s, l, t, f = activation_tensor.size()
    response = dict()
    response["stats"] = dict()
    response["stats"]["global"] = get_single_lang_statistics(
        activation_tensor.reshape(-1, t, f)
    )

    for idx in range(l):
        response["stats"][f"lang_{idx}"] = get_single_lang_statistics(
            activation_tensor[:, idx, :, :]
        )

    return response


activation_stats = get_activation_statistics(
    activation_tensor[:, :, 1:, :]
)  # Ignoring BOS token


# Determine the number of groups and tensors
num_groups = len(activation_stats["stats"])
num_tensors = len(activation_stats["stats"]["global"])

titles = {
    "global": "Percentage of samples that each feature activated (ordered)",
    "lang_0": "Percentage of activated Spanish samples (ordered)",
    "lang_1": "Percentage of activated English samples (ordered)",
}


def print_activation_stats(group_name, group_data):
    fig = go.Figure()
    for series_name, series_data in group_data.items():
        fig.add_trace(
            go.Scatter(
                x=np.arange(series_data["series"].size(0)),  # X-axis: indices of tensor
                y=series_data["series"].numpy(),  # Y-axis: tensor values
                mode="lines",  # Line plot
                name=f"{series_name}={series_data['value']:.2f}",  # Name for the legend
                showlegend=True,
            )
        )

    # Update layout
    fig.update_layout(
        title=titles[group_name],
        xaxis_title="Ordered SAE Features",
        yaxis_title="Percentage of Activating Samples",
        legend_title="Activation Threshold",
        template="plotly_white",
        xaxis_type="log",
    )

    fig.show()


# print_activation_stats("global", activation_stats["stats"]["global"])
print_activation_stats("lang_0", activation_stats["stats"]["lang_0"])
print_activation_stats("lang_1", activation_stats["stats"]["lang_1"])

On the graphs above: For each feature, we can see the percentage of dataset samples for which they activated. By setting ever increasing thresholds for the activation logits, we can also see by how much they activated. We used the 5%, 25%, 50%, 75%, and 95% quantiles, as well as the mean activation as thresholds.

Note that the x-axis is not showing feature ids, but the features ordered by the percentage of samples that activated them.

The BOS (Beginning of Sequence) token is not considered in the statistics, as it is a special token against which the SAE's features have an immense positive activation bias.

Ignoring the BOS token will be a recurring theme in the rest of the analysis.

Apart from the exponential decay in the percentage of samples that activate a feature, we can see that both Spanish and English samples have a very similar overall distribution of activations, with the English subset having slightly higher activation intensities. This is a good sign that bilingual features are not only present but probably common in this particular SAE.
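To make the quantity plotted above concrete, here is a small dependency-free sketch on toy data. The function name `firing_fraction` and the toy activations are illustrative assumptions; the notebook computes the same kind of statistic with tensors.

```python
def firing_fraction(acts, threshold=0.0):
    """Per-feature fraction of samples whose maximum activation exceeds a threshold.

    acts: nested list of shape (samples, tokens, features), with the BOS token
    at position 0 of every sample. Each sample needs at least one non-BOS token.
    """
    n_features = len(acts[0][0])
    fractions = []
    for j in range(n_features):
        fired = sum(
            1
            for sample in acts
            if max(tok[j] for tok in sample[1:]) > threshold  # skip BOS
        )
        fractions.append(fired / len(acts))
    return fractions


# Two samples, three tokens each (BOS first), two features.
toy = [
    [[9.0, 9.0], [0.0, 0.5], [0.0, 0.0]],
    [[9.0, 9.0], [0.0, 0.0], [0.0, 2.0]],
]
print(firing_fraction(toy))                 # [0.0, 1.0]
print(firing_fraction(toy, threshold=1.0))  # [0.0, 0.5]
```

The large BOS activations never enter the result, and raising the threshold can only lower each fraction, which is why the curves for higher quantiles sit below the others.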

Bilingual Interpretability Score

We define a bilingual interpretability score by having separate bilingual and interpretability components. The bilingual component is a measure of how similar the activations of a feature are across languages, while the interpretability component is a measure of activation frequency which, in this particular scenario, serves as a proxy for how easy it would be to interpret a feature.

Bilingual Loss

Let’s say we have our dataset \(D\), which is a disjoint union of \(D_{es}\) and \(D_{en}\) with \(|D_{es}| = |D_{en}| = n\), i.e., the English and Spanish datasets have the same size.

Let \((d_{es}^k, d_{en}^k)\) be the natural k-th pairing of Spanish-English samples.

Let \(d\_sae\) be the number of features we have for our SAE.

Let \(f\) be the composition of our model up to the hooked layer and our SAE encoder, such that \(f(d) \in \mathbb{R}^{ctx\_size \times d\_sae}\) is the matrix of SAE feature activations for dataset example \(d\), with \(ctx\_size\) being the number of tokens of example \(d\).

Then, for each feature \(F_i\) with \(i \in [0, d\_sae-1]\), we can define a bilingual scoring function \(BF_{D}(\cdot)\) by converting the activation vector for each language into a distribution with a softmax and then applying a symmetric measure of distance such as the Jensen-Shannon divergence.

We chose the Jensen-Shannon divergence instead of things like cross-entropy because it is symmetric and always non-negative. The symmetry is important, since we should be comparing both languages without preference for one or the other. Also, this divergence is naturally extendable to more than two languages, which is a feature we might want to explore in the future.

Formally, let

\[ q^i_{lang} := [\max(f(d_{lang}^1)_i), \cdots, \max(f(d_{lang}^n)_i)]^T \in \mathbb{R}^{n} \]

with \(lang \in \{en, es\}\) and \(i \in [0, d\_sae-1]\), be the vector that collects, for each sample of language \(lang\), the maximum activation logit of the feature \(F_i\) over the tokens of that sample.

Then:

\[ BF_{D}(F_i) := JSD(softmax(q^i_{es}), softmax(q^i_{en})) \]

With \(JSD\) being the Jensen-Shannon divergence given by:

\[ JSD(p, q) := \frac{1}{2} D_{KL}(p \,\|\, \textbf{M}) + \frac{1}{2} D_{KL}(q \,\|\, \textbf{M}) \]

Where \(\textbf{M}\) is the mixture distribution \(\frac{1}{2}(p + q)\) and \(D_{KL}\) is the standard Kullback-Leibler divergence.

We ran into numerical precision problems with this definition, so some changes were made, which are detailed in Section 5.1. We do not think this changes the overall interpretation of the score, nor the results of the analysis, so we left the original, easier to understand, definition in the main text.
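As a rough illustration of the definition of \(BF_{D}\), the following dependency-free sketch computes the score for a single feature from toy per-sample maxima. It follows the textual definition (a sum inside the KL divergence), the function names are our own, and the activation values are invented for illustration.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """D_KL(p || q), using the convention 0 * log(0 / q) = 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def bilingual_score(max_acts_es, max_acts_en):
    """BF_D(F_i) for one feature, given its per-sample maximum activations."""
    return jsd(softmax(max_acts_es), softmax(max_acts_en))


# Toy per-sample maxima for one feature over four sentence pairs.
same = bilingual_score([0.0, 3.0, 0.0, 1.0], [0.0, 3.0, 0.0, 1.0])
swapped = bilingual_score([0.0, 3.0, 0.0, 1.0], [3.0, 0.0, 1.0, 0.0])
dead = bilingual_score([0.0] * 4, [0.0] * 4)
print(same, swapped, dead)
```

A feature that fires identically on both sides of every pair scores zero, one that fires on different pairs scores strictly above zero (and at most \(\log 2\)), and a feature that never fires also scores zero because both softmaxes are uniform.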

Random idea: Can we force multilinguality in the SAEs by training them with multilingual samples and using the above loss (batched clearly) as complementary to the reconstruction and sparsity losses?

The prior metric gives us zero when our feature is perfectly bilingual (it activates not only in the same samples but also with the same magnitude). This definition has a clear drawback: completely dead features (which produce zeros for all tokens) are perfectly bilingual, so we ended up filtering them out for much of the analysis.

Ease-of-Interpretability Loss

We did not want to deal with those features that activated for a large portion of the dataset, since they are not only hard to interpret, but also probably not very useful for our purposes. We borrowed ideas from Information Retrieval and defined an Inverse Document Activation Frequency (IDAF) for each feature, which is the inverse of the fraction of samples that activate a feature. Formally:

\[ idaf_{D}(F_i) := \frac{|D|}{\sum_{d \in D} \textbf{1}_{\max(f(d)_i) > 0}} \]

Note that a feature with zero activations in the whole set will cause a division by zero error. That is another reason we need to filter those out first.
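A minimal sketch of this quantity for a single feature, assuming we already have its per-sample maximum activations pooled over both languages. The function name `idaf` and the choice of returning `None` for dead features are illustrative; the notebook filters dead features with a mask instead.

```python
def idaf(max_acts):
    """Inverse Document Activation Frequency of one feature.

    max_acts: per-sample maximum activation of the feature (BOS excluded),
    pooled over both languages. Returns None for a dead feature, which
    would otherwise cause a division by zero.
    """
    n_active = sum(1 for a in max_acts if a > 0)
    if n_active == 0:
        return None
    return len(max_acts) / n_active


print(idaf([0.0, 2.0, 0.0, 0.0]))  # 4.0: fires in one sample out of four
print(idaf([1.0, 2.0, 0.5, 3.0]))  # 1.0: fires in every sample
print(idaf([0.0, 0.0, 0.0, 0.0]))  # None: dead feature
```

The value is 1 for a feature that fires everywhere and grows as the feature becomes rarer, mirroring the inverse document frequency used in information retrieval.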

And the final score for a given feature \(F_i\) will be:

\[ Bilingual\_Interpretability_{D}(F_i) := BF_{D}(F_i) + \beta \cdot idaf_{D}(F_i) \]

The \(\beta\) parameter is a hyperparameter that we tuned specifically so that the \(idaf\) term only acted as a reranker for the top features that were already good in the bilingual component. The specifics of this tuning are detailed in Section 5.2.

Feature Interpretation

When we come across a candidate bilingual feature, the final step is to try to interpret it. There are many tools for this task, and one usually cannot be 100% sure of the interpretation of a feature. Nevertheless, we can use the feature dashboards and the maximum activation samples from our dataset to get a sense of what the feature is activating for.

For example, the following Neuronpedia dashboard shows the top activations for the feature with id 10194, which seems to be related to the concept of the United States:

from IPython.display import IFrame

# get a random feature from the SAE
feature_idx = 2647

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
neuronpedia_sae, neuronpedia_id = "gemma-2-2b/20-gemmascope-res-16k".split("/")

def get_dashboard_html(feature_idx=0, sae_release = neuronpedia_sae, sae_id=neuronpedia_id):
    return html_template.format(sae_release, sae_id, feature_idx)

US_FT_IDX = 10194
IFrame(get_dashboard_html(US_FT_IDX), width=1200, height=600)

For our use case, Neuronpedia has a major limitation, which is that we cannot filter the samples by language.

To overcome this limitation, we used the circuitsvis library to visualize the top activating samples from our own dataset, which we can easily filter by language (Cooney and Nanda, 2023). For example, the following shows the top activations for the previously shown feature with index 10194, now with the activations for both languages:

from circuitsvis.tokens import colored_tokens


def get_tokens_and_acts(
    idx, ft_idx, ds=sample_ds, activations=activation_tensor, include_bos=False
):
    start_idx = 0 if include_bos else 1
    str_tokens_es = ds[idx]["str_tokens"]["es"][start_idx:]
    str_tokens_en = ds[idx]["str_tokens"]["en"][start_idx:]

    # Slice exactly one activation per displayed token, with or without BOS
    token_act_es = activations[idx, 0, start_idx : start_idx + len(str_tokens_es), ft_idx]
    token_act_en = activations[idx, 1, start_idx : start_idx + len(str_tokens_en), ft_idx]

    t = ["<b>EN:</b>  "] + str_tokens_en + ["      <b>ES:</b>  "] + str_tokens_es
    a = [0] + token_act_en.tolist() + [0] + token_act_es.tolist()

    return colored_tokens(t, a)


greatest_act = (
    activation_tensor[:, 1, 1:, US_FT_IDX].max(dim=-1).values.sort(descending=True)
)
get_tokens_and_acts(greatest_act.indices[0].item(), US_FT_IDX)
get_tokens_and_acts(greatest_act.indices[1].item(), US_FT_IDX)

Finally, we can get a general sense of what our score is actually looking for by plotting the logits of a given feature across the whole dataset.

import matplotlib.pyplot as plt

idx = US_FT_IDX

display_ctxt_size = 70
fig_es = activation_tensor[:, 0, 1:, idx][:, :display_ctxt_size].numpy()
fig_en = activation_tensor[:, 1, 1:, idx][:, :display_ctxt_size].numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # 1 row, 2 columns

im_en = axes[0].imshow(fig_en, cmap="hot", interpolation="nearest", aspect="auto")
im_es = axes[1].imshow(fig_es, cmap="hot", interpolation="nearest", aspect="auto")
fig.colorbar(im_en, ax=axes[0])
fig.colorbar(im_es, ax=axes[1])
axes[0].set_title(f"Activation Feature {idx} for EN")
axes[0].set_xlabel("Tokens")
axes[0].set_ylabel("Samples")
axes[1].set_title(f"Activation Feature {idx} for ES")
axes[1].set_xlabel("Tokens")
axes[1].set_ylabel("Samples")
plt.show()

Results

In this section we summarize the main results of our analysis. We will show the top features according to the Bilingual Interpretability Score together with their possible interpretations.

Table 1 shows the top 11 ranked features according to the Bilingual Interpretability Score (lower scores are better). We can see how these top ranked features are indeed activating for concepts in both languages.

| Position | Feature ID | Possible Interpretation | BI Loss Value |
|---|---|---|---|
| 0 | 6530 | Not clearly interpretable | 4.442881974691051e-08 |
| 1 | 2009 | References to Quality and Calidad | 4.4428819793659613e-08 |
| 2 | 7502 | Related to Afghanistan / Afganistán: places, military organizations (e.g., Taliban) | 4.442882938285578e-08 |
| 3 | 4275 | Feature for the name Andrew and the Spanish equivalent Andrés | 4.442882967726952e-08 |
| 4 | 2760 | References to Price / Precio | 4.442889617501746e-08 |
| 5 | 10194 | Feature referring to United States / Estados Unidos / US / America | 4.442939963824309e-08 |
| 6 | 13963 | References to scientific measurements | 4.443070022280122e-08 |
| 7 | 2054 | References to February / Febrero and the number 2 in the context of months | 4.443856851443512e-08 |
| 8 | 4762 | References to Degrees / Grados Celsius / Atmospheric temperature | 4.444231090244995e-08 |
| 9 | 14036 | Not clearly interpretable | 4.4451682604606074e-08 |
| 10 | 12412 | References to Education / Educar / Instruir | 4.445858874005918e-08 |
Table 1: Top features according to the Bilingual Interpretability Score. The position is taken after filtering out the dead features and all the scoring calculation was done ignoring the BOS token of each sample

Also, the actual scores seem to be really small numbers, with little separation between them. To get a sense of the distribution of the scores, we can plot them as a line plot:

import torch.nn.functional as F


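def ur_kldiv(input: torch.Tensor, target: torch.Tensor, dim=0) -> torch.Tensor:
    # Unreduced, pointwise KL term. With the (input, target) argument order used
    # here it evaluates target * log(target / input), i.e. the summand of
    # D_KL(target || input). The `dim` argument is unused.
    return target * (target.log() - input.log())


def JSD(p_logits: torch.Tensor, q_logits: torch.Tensor, dim=0) -> torch.Tensor:
    # Inputs have shape (samples, features); the softmax runs over samples.
    # Cast to float64 and rescale so the largest logit is 100 (see Section 5.1).
    p, q = p_logits.to(torch.float64), q_logits.to(torch.float64)
    p = p * 100 / p.max()
    q = q * 100 / q.max()

    p = p.softmax(dim=dim)
    q = q.softmax(dim=dim)
    m = 0.5 * (p + q)

    # As written, the pointwise terms are averaged (not summed) over samples, and
    # with the argument order above they are the summands of D_KL(m || p) and
    # D_KL(m || q).
    jsd = 0.5 * (ur_kldiv(p, m).mean(dim=0) + ur_kldiv(q, m).mean(dim=0))
    return jsd


def Idaf(activations):
    # activations: (samples, languages, tokens, features)
    d, l, t, f = activations.size()
    # Number of (sample, language) rows divided by the number of rows in which
    # each feature fires at least once
    idaf = (d * l) / (activations.view(-1, t, f).max(dim=1).values > 0).sum(0)
    # Features firing in fewer than 10 rows are assigned the minimum idaf of 1.0
    idaf = torch.where(((d * l) / idaf) < 10, 1.0, idaf)
    # The sign is flipped here; beta (computed below) is negative as well
    return -idaf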


# We will get the indices of the features that are dead
es_dead_features = (activation_tensor[:, 0, 1:, :].max(dim=1).values > 0).sum(0) == 0
en_dead_features = (activation_tensor[:, 1, 1:, :].max(dim=1).values > 0).sum(0) == 0
dead_features = es_dead_features | en_dead_features

# Activations for the first and second language ignoring the BOS token
es_acts = (
    activation_tensor[:, 0, 1:, :].max(dim=1).values
)  # From 1: in third dim to filter-out BOS
en_acts = activation_tensor[:, 1, 1:, :].max(dim=1).values

jsd = JSD(es_acts, en_acts)
idaf = Idaf(activation_tensor[:, :, 1:, :])

# Calculate the beta hyperparameter
beta_num = -(
    jsd[~dead_features].sort(descending=False).values[300] - jsd[~dead_features].min()
).item()
beta = beta_num / (idaf[~dead_features].max() - idaf[~dead_features].min())

# Calculate the BI score
BI = jsd + beta * idaf

# We will sort the features by the BI score and filter out the dead features
unfiltered = BI.sort(descending=False)
mask = ~dead_features
idx_order = unfiltered.indices[mask[unfiltered.indices]]
idx_value = unfiltered.values[mask[unfiltered.indices]]

# We will plot the distribution of the BI score

fig = go.Figure()
fig.add_trace(
    go.Scatter(x=np.arange(idx_value.size(0)), y=idx_value.numpy(), mode="lines")
)
fig.update_layout(
    title="Distribution of the <b>Bilingual Interpretability Score</b> on alive features",
    xaxis_title="SAE features post filtering of dead features",
    yaxis_title="BI loss value",
    template="plotly_white",
)
fig.show()

Our score seems to behave in a clear exponential fashion, where the top features are close to 7 orders of magnitude better than the worst ones.

Discussion

In this work, we presented a systematic, data-driven process to generate a list of candidate bilingual features from a GemmaScope SAE. We defined a bilingual interpretability score for each feature, which is dependent on a dataset of equivalent English-Spanish sentences. We then ranked the features based on this score and analyzed them.

We showed how this otherwise simple approach can be used to find features that are not only bilingual, but also interpretable.

Crucially, we showed that we can use data-based approaches to actively look for interpretable features with some specific properties, in this case, bilingualism. This is a clear departure from the usual approach of finding features that are interpretable by chance, and it is a step towards a more systematic and data-driven approach to mechanistic interpretability.

Limitations

This was a rather time- and resource-constrained project, so there is a lot of room for improvement. For example, the dataset sample size we used was very small, and the SAE we used was not the most powerful one. We also did not thoroughly explore the effect of the \(\beta\) hyperparameter on the final results, nor did we explore the effect of the dataset size on Bilingual Interpretability Scores.

Also, numerical precision problems with the Jensen-Shannon divergence forced us to make some changes to the original definition of the score, which may have affected the results, and which ideally should be explored further and more rigorously.

Moreover, the idaf score is a very simplistic choice based only on the activation frequency of the features, which may not be the best proxy for interpretability. This is a very naive approach, and it would be interesting to explore more sophisticated ways of numerically assessing the interpretability of a feature.

Finally, the interpretation step itself is very manual and subjective, and it would be interesting to explore ways to automate this process, or at least to make it more systematic.

Future Work

There are many ways in which this work can be extended. For example, we could explore the effect of the dataset size on the Bilingual Interpretability Score, or we could explore the effect of the \(\beta\) hyperparameter on the final results.

Also, we pointed out that the use of the Jensen-Shannon divergence as a measure of similarity between the activations of a feature across languages can be extended to more than two languages, opening the possibility of finding features that are multilingual in a more general sense.
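One standard way to write this extension is the generalized Jensen-Shannon divergence with uniform weights, \(H(\frac{1}{m}\sum_k P_k) - \frac{1}{m}\sum_k H(P_k)\), where \(H\) is the Shannon entropy. The sketch below is only an illustration of that formula: the third language and all the distributions are hypothetical, and none of this was run on our data.

```python
import math


def entropy(p):
    """Shannon entropy in nats, with 0 * log(0) taken as 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def generalized_jsd(distributions):
    """Uniform-weight JSD of several distributions.

    Computed as the entropy of the mixture minus the mean entropy of the
    individual distributions.
    """
    n = len(distributions)
    mixture = [sum(col) / n for col in zip(*distributions)]
    return entropy(mixture) - sum(entropy(p) for p in distributions) / n


# Toy softmaxed activation distributions over three sentence triples.
es = [0.7, 0.2, 0.1]
en = [0.7, 0.2, 0.1]
fr = [0.1, 0.2, 0.7]  # hypothetical third language that disagrees

print(generalized_jsd([es, en]))      # zero up to rounding: identical distributions
print(generalized_jsd([es, en, fr]))  # strictly positive: one language disagrees
```

For two distributions this reduces to the two-language definition used in the main text, and for \(m\) distributions the value is bounded above by \(\log m\).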

Our ideas can also be used in the training phase of the SAEs, to force them to learn multilingual features, and to explore the effect of this on the final performance of the model. By setting the Jensen-Shannon divergence as a complementary loss to the reconstruction and sparsity losses, and by training the model with a portion of multilingual samples, we could force the model to learn features that are not only interpretable but also multilingual.

Finally, we could explore how the choice of the hook point in the model affects the final results, and the distribution of multilingual features in the SAE.

Acknowledgements

This work is the final project for the course on AI Safety and Alignment from BlueDot Impact. I would like to thank the BlueDot team for their support and guidance throughout the course. In particular, I would like to thank my facilitator, Aaron Scher, for his insightful feedback and encouragement. I would also like to thank my peers for their feedback and support in the project phase.

Appendix

Considerations for Better Numerical Properties of the JSD

The calculation of the Jensen-Shannon divergence as we defined it has some numerical problems. The main one is that the softmax function can produce very small numbers, which can lead to numerical instability when calculating the logarithm needed to compute the Kullback-Leibler divergence. Remember, the Kullback-Leibler divergence for two discrete probability distributions \(P\) and \(Q\) is defined as:

\[ D_{KL}(P \,\|\, Q) = \sum_{i} P(i) \log \left(\frac{P(i)}{Q(i)}\right) \]

The problem here is that when \(P(i)\) is very close to zero, the logarithm becomes a very large negative number, and once \(P(i)\) underflows to exactly zero it evaluates to -inf. Even though, theoretically, the Jensen-Shannon divergence is always well defined and non-negative, in practice we can run into these problems, which lead to nan values in the final score.

Sadly, the raw logits of our features can be very large for the maximum activations, but also very small for the rest of the activations, such that it is almost certain that we will run into this problem after applying the softmax, specifically to those features where the maximum activation is very large in relation to the mean activation.

To overcome this issue, we rescaled the raw logit distributions before applying the softmax. In our case, we simply chose a target value of 100 and normalized all the logits so that the maximum logit was 100. Formally, before applying the softmax, we did:

\[ q^i_{lang} := \frac{100}{m_{lang}} \cdot [\max(f(d_{lang}^1)_i), \cdots, \max(f(d_{lang}^n)_i)]^T \in \mathbb{R}^{n} \]

where \(m_{lang}\) is the largest maximum activation logit observed for language \(lang\) (in our code, the maximum over all samples and features), so that the largest rescaled logit equals 100.

This simple trick allowed us to avoid the numerical problems we were facing, and to calculate the Jensen-Shannon divergence without any issues.

We chose 100, since we noted that this will result in numbers close to the maximum precision of float64, making the best compromise between numerical stability and precision.
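The effect of the rescaling can be seen in a small dependency-free sketch. The activation values are invented, the helper names are our own, and plain Python floats (which are float64) stand in for the tensors used in the notebook.

```python
import math


def softmax(xs):
    """Softmax with the usual max-subtraction; small entries can still underflow."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def rescale(xs, cap=100.0):
    """Scale non-negative activations so that the largest one equals `cap`."""
    m = max(xs)
    return [cap * x / m for x in xs] if m > 0 else list(xs)


raw = [1000.0, 0.0, 5.0]  # one very large maximum activation (toy values)
print(softmax(raw))           # the two small entries underflow to exactly 0.0
print(softmax(rescale(raw)))  # all entries stay strictly positive
```

With exact zeros in the distribution, any term that takes a logarithm of them produces -inf or nan, whereas after rescaling the smallest probability is on the order of \(e^{-100}\), which float64 represents without trouble.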

On the \(\beta\) Hyperparameter

The \(\beta\) hyperparameter was chosen to, at most, be able to change the order of the top 300 features. In other words, we wanted the range between the minimum and maximum values of the \(idaf\) score to be no more than the range between the minimum and the 300th value of the Jensen-Shannon divergence score. Formally, we defined \(\beta\) as:

\[ \beta = - \frac{JSD_{300} - JSD_{min}}{IDAF_{max} - IDAF_{min}} \]

Where \(JSD_{300}\) is the 300th value of the Jensen-Shannon divergence score, and \(JSD_{min}\) is the minimum value of the Jensen-Shannon divergence score. \(IDAF_{max}\) and \(IDAF_{min}\) are the maximum and minimum values of the Inverse Document Activation Frequency score, respectively.

Also, the negative sign is there to make the top features have lower scores.
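The choice of \(\beta\) and the resulting reranking can be sketched on toy numbers. The sketch follows the sign convention of the formulas in the text (positive \(idaf\), negative \(\beta\)), the function names are our own, and the scores are illustrative rather than measured.

```python
def beta_for_top_k(jsd_scores, idaf_scores, k=300):
    """Negative beta such that the idaf term spans the JSD range of the top k.

    Both lists are assumed to contain alive features only, with at least two
    distinct idaf values. `k` is clamped for short toy lists.
    """
    ordered = sorted(jsd_scores)
    k = min(k, len(ordered) - 1)
    jsd_span = ordered[k] - ordered[0]
    idaf_span = max(idaf_scores) - min(idaf_scores)
    return -jsd_span / idaf_span


def bi_scores(jsd_scores, idaf_scores, beta):
    """BI(F_i) = BF(F_i) + beta * idaf(F_i); lower is better."""
    return [j + beta * f for j, f in zip(jsd_scores, idaf_scores)]


# Toy values for five alive features.
jsd_toy = [0.10, 0.11, 0.12, 0.50, 0.90]
idaf_toy = [1.0, 5.0, 2.0, 3.0, 4.0]

beta = beta_for_top_k(jsd_toy, idaf_toy, k=2)
scores = bi_scores(jsd_toy, idaf_toy, beta)
ranking = sorted(range(len(scores)), key=scores.__getitem__)
print(beta, ranking)
```

In this toy case the ranking becomes [1, 0, 2, 3, 4]: the rarer feature 1 overtakes feature 0 inside the top group, while features 3 and 4, whose divergence is much larger, keep their positions. That is the reranking behaviour the bound on \(\beta\) is meant to guarantee.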

References

Bloom, J. and Chanin, D. (2024) SAELens. https://github.com/jbloomAus/SAELens.
Bricken, T., Templeton, A., Batson, J., et al. (2023) Towards monosemanticity: Decomposing language models with dictionary learning. Transformer Circuits Thread, 2. Anthropic.
Cooney, A. and Nanda, N. (2023) CircuitsVis. https://github.com/TransformerLensOrg/CircuitsVis.
Cunningham, H., Ewart, A., Riggs, L., et al. (2023) Sparse autoencoders find highly interpretable features in language models. arXiv preprint arXiv:2309.08600.
Elhage, N., Nanda, N., Olsson, C., et al. (2021) A mathematical framework for transformer circuits. Transformer Circuits Thread.
Gao, L., Tour, T. D. la, Tillman, H., et al. (2024) Scaling and evaluating sparse autoencoders. arXiv preprint arXiv:2406.04093.
Lieberum, T., Rajamanoharan, S., Conmy, A., et al. (2024) Gemma scope: Open sparse autoencoders everywhere all at once on gemma 2. arXiv preprint arXiv:2408.05147.
Olah, C., Cammarata, N., Schubert, L., et al. (2020) Zoom in: An introduction to circuits. Distill, 5, e00024–001.
Rajamanoharan, S., Conmy, A., Smith, L., et al. (2024) Improving dictionary learning with gated sparse autoencoders. arXiv preprint arXiv:2404.16014.
Rajamanoharan, S., Lieberum, T., Sonnerat, N., et al. (2024) Jumping ahead: Improving reconstruction fidelity with JumpReLU sparse autoencoders. arXiv preprint arXiv:2407.14435.
Team, G., Riviere, M., Pathak, S., et al. (2024) Gemma 2: Improving open language models at a practical size. arXiv preprint arXiv:2408.00118.
Templeton, A., Conerly, T., Marcus, J., et al. (2024) Scaling monosemanticity: Extracting interpretable features from claude 3 sonnet. Transformer Circuits Thread. Available at: https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html.
Tiedemann, J. (2012) Parallel data, tools and interfaces in OPUS. In: Proceedings of the eighth international conference on language resources and evaluation (LREC’12) (eds. N Calzolari, K Choukri, T Declerck, et al.), Istanbul, Turkey, May 2012, pp. 2214–2218. European Language Resources Association (ELRA). Available at: http://www.lrec-conf.org/proceedings/lrec2012/pdf/463_Paper.pdf.
Wang, K., Variengien, A., Conmy, A., et al. (2022) Interpretability in the wild: A circuit for indirect object identification in gpt-2 small. arXiv preprint arXiv:2211.00593.