ReFT

1. AxBench

Live Demo: https://huggingface.co/spaces/pyvene/AxBench-ReFT-r1-16K

AxBench evaluates interpretability methods on two tasks: concept detection and model steering. Alongside the benchmark, AxBench releases two supervised dictionary learning methods that outperform existing methods, including SAEs. These dictionaries contain 1D subspaces that map to high-level concepts.
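
To make the idea of a 1D concept subspace concrete, here is a minimal pure-Python sketch. It assumes a concept is a single direction w in the residual stream. Detection is read off as the ReLU of the projection of a hidden state onto w, mirroring the ReLU encoder used in section 3. Steering is modeled as adding a scaled, unit-normalized copy of w. The helper names (dot, detect, steer) and the toy vectors are hypothetical, and this additive form of steering is an illustrative assumption, not a description of AxBench's exact procedure.

```python
import math

# Toy illustration only: a concept is assumed to be one direction w
# in the residual stream. This is not AxBench's own code.

def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def detect(h, w):
    """Concept score: ReLU of the projection of hidden state h onto w."""
    return max(0.0, dot(h, w))

def steer(h, w, alpha):
    """Assumed steering rule: add alpha times the unit-normalized w to h."""
    norm = math.sqrt(dot(w, w))
    return [x + alpha * y / norm for x, y in zip(h, w)]

w = [0.0, 3.0, 4.0]    # toy concept direction (norm 5)
h = [1.0, -1.0, 0.5]   # toy hidden state

print(detect(h, w))            # 0.0 (projection is -1.0, clipped by ReLU)
steered = steer(h, w, alpha=10.0)
print(steered)                 # [1.0, 5.0, 8.5]
print(detect(steered, w))      # 49.0
```

Pushing the hidden state along the concept direction turns a score of zero into a strongly positive one, which is the link between the two tasks: the same direction serves as a detector and as a steering vector.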

2. What is gemma-reft-2b-it-res?

It is a single dictionary of subspaces for 16K concepts and serves as a drop-in replacement for SAEs.

  • gemma-: refers to Gemma 2 models.
  • reft-: the dictionary is trained with representation finetuning (ReFT); see the ReFT paper for details.
  • 2b-it-: the dictionary is for the Gemma 2 2B instruction-tuned model.
  • res: the dictionary is trained on the model's residual stream.
  • We release the weights as well as the annotated concepts for all subspaces.
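
A single dictionary can be pictured as a matrix W with one column per concept. The toy sketch below uses hypothetical names and a made-up 2-dimensional hidden state with 4 concepts instead of the real 16K. It encodes a hidden state h into one non-negative latent per concept as ReLU(h · W), then lists the most active concepts.

```python
# Toy dictionary: W has shape (hidden_dim, num_concepts), one column per
# concept. Sizes and values are made up for illustration.

def encode(h, W):
    """Return ReLU(h @ W): one non-negative latent per concept column."""
    num_concepts = len(W[0])
    return [max(0.0, sum(h[i] * W[i][j] for i in range(len(h))))
            for j in range(num_concepts)]

def top_k(latents, k):
    """Indices of the k largest latents, highest first."""
    return sorted(range(len(latents)), key=lambda j: latents[j], reverse=True)[:k]

W = [
    [1.0, 0.0, -1.0, 0.5],
    [0.0, 2.0,  1.0, 0.25],
]  # hidden_dim = 2, num_concepts = 4
h = [3.0, 1.0]

z = encode(h, W)
print(z)            # [3.0, 2.0, 0.0, 1.75]
print(top_k(z, 2))  # [0, 1]
```

Storing one column per concept is the same computation as a bias-free `torch.nn.Linear` whose weight holds one row per concept, since `nn.Linear` stores its weight as (out_features, in_features).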

3. How can I use these dictionaries straight away?

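import torch
import pyvene as pv
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model (Gemma 2 2B, instruction-tuned) and its tokenizer.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it").cuda()
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

# Create an intervention that reads concept latents from the residual stream.
class Encoder(pv.CollectIntervention):
    """An intervention that reads concept latents from the residual stream."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        # Bias-free linear map with one output per concept in the dictionary.
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)
    def forward(self, base, source=None, subspaces=None):
        # Non-negative concept activations at every token position.
        return torch.relu(self.proj(base))

# Load the layer-20 dictionary weights.
path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
params = torch.load(path_to_params)
# nn.Linear expects (latent_dim, hidden_size); orient the tensor by matching the hidden size.
hidden_size = model.config.hidden_size
weight = params if params.shape[1] == hidden_size else params.t()
encoder = Encoder(embed_dim=hidden_size, latent_dim=weight.shape[0])
encoder.proj.weight.data = weight.to(device=model.device, dtype=model.dtype)

# Mount the loaded intervention on the output of layer 20.
pv_model = pv.IntervenableModel({
    "component": "model.layers[20].output",
    "intervention": encoder}, model=model)

# Use pv_model like any other torch model; it returns the collected subspace latents.
prompt = "Would you be able to travel through time using a wormhole?"
input_ids = torch.tensor([tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True)]).cuda()
acts = pv_model.forward(
    {"input_ids": input_ids}, return_dict=True).collected_activations[0]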
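
The collected activations hold one latent per concept at every token position. One common way to summarize them for a whole prompt is to max-pool over tokens and keep the highest-scoring concepts. The sketch below assumes acts has been converted to a nested list of shape (seq_len, num_concepts); a real tensor may carry an extra batch dimension that would need to be squeezed first. The function names and toy values are hypothetical.

```python
# Assumption: acts is a (seq_len, num_concepts) grid of per-token latents,
# written here as nested lists. Values are made up for illustration.

def max_pool_over_tokens(acts):
    """Per-concept maximum activation across all token positions."""
    return [max(col) for col in zip(*acts)]

def top_concepts(acts, k):
    """(concept_index, pooled_score) pairs for the k strongest concepts."""
    pooled = max_pool_over_tokens(acts)
    ranked = sorted(range(len(pooled)), key=lambda j: pooled[j], reverse=True)
    return [(j, pooled[j]) for j in ranked[:k]]

acts = [
    [0.0, 0.4, 0.0],  # token 0
    [1.2, 0.0, 0.1],  # token 1
    [0.3, 0.9, 0.0],  # token 2
]

print(max_pool_over_tokens(acts))  # [1.2, 0.9, 0.1]
print(top_concepts(acts, 2))       # [(0, 1.2), (1, 0.9)]
```

With a PyTorch tensor, the same reduction is `acts.max(dim=0).values` followed by `torch.topk`. The resulting indices could then be looked up in the released concept annotations, whose file layout is not described here.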

4. Point of Contact

Point of contact: Zhengxuan Wu or Aryaman Arora

Contact by email:

{wuzhengx, aryamana}@stanford.edu

5. Citation

Paper:
