AxBench Release
Collection
Open supervised dictionary learning models and datasets for Gemma 2 2B and 9B instruction-tuned models.
Live Demo: https://huggingface.co/spaces/pyvene/AxBench-ReFT-r1-16K
AxBench evaluates interpretability methods on two tasks: concept detection and model steering. It releases two supervised dictionary learning methods that outperform existing methods, including sparse autoencoders (SAEs). These dictionaries contain 1D subspaces that map to high-level concepts.
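To make the two evaluation axes concrete, the toy sketch below treats one dictionary entry as a single direction in the hidden space. Detection reads a concept score as the ReLU of the projection onto that direction, and steering adds a scaled copy of the direction to the hidden state. The vectors, the `strength` parameter, and the additive steering rule are illustrative assumptions written in dependency-free Python, not the AxBench implementation.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def detect(hidden, direction):
    """Concept score: ReLU of the projection of `hidden` onto `direction`."""
    return max(0.0, dot(hidden, direction))


def steer(hidden, direction, strength):
    """Push `hidden` along `direction` by `strength` (assumed additive rule)."""
    return [h + strength * d for h, d in zip(hidden, direction)]


# Hypothetical 4-dimensional hidden state and unit-norm concept direction.
hidden = [0.5, -1.0, 2.0, 0.0]
direction = [0.0, 0.0, 1.0, 0.0]

score_before = detect(hidden, direction)
steered = steer(hidden, direction, strength=3.0)
score_after = detect(steered, direction)
print(score_before, score_after)
```

In this toy setup, steering along the concept direction raises the detection score for the same concept, which is the link between the two tasks.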
What is gemma-reft-2b-it-res?
It is a single dictionary of subspaces for 16K concepts and serves as a drop-in replacement for SAEs.
- gemma-: Refers to the Gemma 2 models
- reft-: The dictionary learning model is trained using representation finetuning (ReFT); see the ReFT paper for details
- 2b-it-: The dictionary is for the Gemma 2 2B instruction-tuned model
- res: The dictionary is trained on the model's residual stream

from huggingface_hub import hf_hub_download
import pyvene as pv
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and tokenizer. The checkpoint name and float32 dtype are
# assumptions: google/gemma-2-2b-it, kept in float32 to match the encoder below.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", torch_dtype=torch.float32).cuda()
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

# Create an intervention.
class Encoder(pv.CollectIntervention):
    """An intervention that reads concept latents from the residual stream."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        return torch.relu(self.proj(base))

# Load the dictionary weights for layer 20.
path_to_params = hf_hub_download(
    repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
params = torch.load(path_to_params).cuda()
encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1])
encoder.proj.weight.data = params.float()

# Mount the loaded intervention on the output of layer 20.
pv_model = pv.IntervenableModel({
    "component": "model.layers[20].output",
    "intervention": encoder}, model=model)

# Use pv_model like any other torch model to collect the subspace latents.
prompt = "Would you be able to travel through time using a wormhole?"
input_ids = torch.tensor([tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=True, add_generation_prompt=True)]).cuda()
acts = pv_model.forward(
    {"input_ids": input_ids}, return_dict=True).collected_activations[0]
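Once `acts` is collected, a natural next step is to ask which concepts fire most strongly on the prompt. The sketch below max-pools activations over token positions and returns the top-k concept indices. It uses nested Python lists as a stand-in and assumes a `[tokens][concepts]` layout; the actual shape of `collected_activations[0]` is not stated here, and `top_concepts` is a hypothetical helper rather than part of pyvene or AxBench.

```python
def top_concepts(acts, k):
    """Return (concept_index, max_activation) pairs for the k strongest concepts.

    `acts` is assumed to be a list of per-token rows, each row holding one
    non-negative activation per concept.
    """
    num_concepts = len(acts[0])
    # Max-pool over tokens: one score per concept.
    pooled = [max(row[c] for row in acts) for c in range(num_concepts)]
    # Rank concept indices by pooled score, strongest first.
    ranked = sorted(range(num_concepts), key=lambda c: pooled[c], reverse=True)
    return [(c, pooled[c]) for c in ranked[:k]]


# Toy activations: 3 tokens x 4 concepts (illustrative values only).
toy_acts = [
    [0.0, 0.2, 0.0, 1.5],
    [0.0, 0.9, 0.0, 0.3],
    [0.1, 0.0, 0.0, 2.0],
]
top = top_concepts(toy_acts, k=2)
print(top)
```

With a real 2D torch tensor in that layout, the same reduction could be written as `acts.amax(dim=0).topk(k)`.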
Point of contact: Zhengxuan Wu or Aryaman Arora
Contact by email: {wuzhengx, aryamana}@stanford.edu
Paper: