Gemma-2-2B fine-tuned to generate Gemma-2-9B-it subspaces from natural-language concept descriptions

In the AxBench paper, we fine-tuned a subspace generator. This is a hyper-network that takes a concept description in natural language and generates a subspace for that concept. Here the generator is a fine-tuned Gemma-2-2B, and each generated subspace is a unit-norm vector for Gemma-2-9B-it. A high-quality subspace generator can bypass dictionary training altogether!
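For each concept description, the generator outputs a single unit-norm vector. The sketch below shows two common ways a concept direction `w` might be used. Projecting hidden states onto `w` gives a per-token concept score, and adding a scaled copy of `w` steers the hidden states. This is an assumption about downstream use, not the AxBench evaluation code. The function names `concept_scores` and `steer`, the factor `alpha`, and the toy dimensions are hypothetical, and the Gemma-2-9B-it layer where such a direction would apply is not specified here.

```python
import torch
import torch.nn.functional as F


def concept_scores(hidden_states: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: project each token's hidden state onto a unit-norm direction.

    hidden_states: (batch, seq_len, d_model); direction: (d_model,)
    Returns per-token scores of shape (batch, seq_len).
    """
    return hidden_states @ direction


def steer(hidden_states: torch.Tensor, direction: torch.Tensor, alpha: float) -> torch.Tensor:
    """Hypothetical steering rule: add alpha * direction at every token position."""
    return hidden_states + alpha * direction


torch.manual_seed(0)
d_model = 16  # toy width; the real Gemma-2-9B-it hidden size is much larger
direction = F.normalize(torch.randn(d_model), dim=-1)  # stand-in for the generator's output
hidden = torch.randn(2, 5, d_model)  # stand-in for Gemma-2-9B-it hidden states

scores = concept_scores(hidden, direction)
steered = steer(hidden, direction, alpha=4.0)
# Since ||direction|| = 1, steering raises every token's score by exactly alpha.
print(torch.allclose(concept_scores(steered, direction), scores + 4.0, atol=1e-5))
```

The unit norm from the generator is what makes `alpha` interpretable: it is the exact amount by which each token's projection onto `w` moves.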

How to use the subspace generator

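import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

class RegressionWrapper(torch.nn.Module):
    """Gemma-2-2B encoder plus a linear head that maps a concept
    description to a unit-norm vector (the generated subspace)."""

    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        # Run the decoder stack only (no LM head) and keep all hidden states.
        outputs = self.base_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hiddens = outputs.hidden_states[-1]
        # Take the final position's representation (assumes no right padding).
        last_token_representations = last_hiddens[:, -1]
        preds = self.regression_head(last_token_representations)
        # L2-normalize so each generated vector has unit norm.
        preds = F.normalize(preds, p=2, dim=-1)
        return preds

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", torch_dtype=torch.bfloat16)
base_tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-2b", model_max_length=512)

# 'model.pth' is this generator's checkpoint; adjust the path to where you saved it.
state_dict = torch.load("model.pth", map_location="cpu")

# hidden_size is Gemma-2-2B's width; output_dim (the length of the generated
# vector) is read from the checkpoint's head instead of being hard-coded.
hidden_size = base_model.config.hidden_size
output_dim = state_dict["regression_head.weight"].shape[0]

subspace_gen = RegressionWrapper(
    base_model, hidden_size, output_dim).bfloat16().to("cuda")
subspace_gen.load_state_dict(state_dict)
subspace_gen.eval()

your_new_concept = "terms related to Stanford University"

inputs = base_tokenizer(your_new_concept, return_tensors="pt").to("cuda")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
with torch.no_grad():
    # Unit-norm vector of shape (output_dim,) for the concept.
    subspace = subspace_gen(input_ids, attention_mask)[0]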
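The wrapper reads the hidden state at position `-1`. That is the last real token only when a sequence has no right padding, so encoding several concept descriptions in one padded batch needs care. One option, which is a sketch and not part of the original card, is to pick each row's last non-padding token from the attention mask. The helper `last_token_pool` is hypothetical; in the wrapper it would replace `last_hiddens[:, -1]`. Setting `base_tokenizer.padding_side = "left"` before batching is a simpler alternative that keeps `[:, -1]` correct.

```python
import torch


def last_token_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: hidden state of the last non-padding token per sequence.

    hidden: (batch, seq_len, d); attention_mask: (batch, seq_len) of 0/1.
    Works for both right and left padding.
    """
    seq_len = attention_mask.shape[1]
    positions = torch.arange(seq_len, device=attention_mask.device)
    # Right-most position whose mask is 1 (argmax returns the largest kept index).
    last_idx = (positions * attention_mask.long()).argmax(dim=1)
    batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
    return hidden[batch_idx, last_idx]


torch.manual_seed(0)
hidden = torch.randn(2, 4, 3)
right_padded = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
left_padded = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])

pooled_right = last_token_pool(hidden, right_padded)  # row 0 uses position 1, row 1 uses 3
pooled_left = last_token_pool(hidden, left_padded)    # both rows use position 3
print(torch.equal(pooled_right[0], hidden[0, 1]), torch.equal(pooled_left[0], hidden[0, 3]))
```

Because the generator L2-normalizes its outputs, a batch of generated vectors `V` gives pairwise cosine similarities directly as `V @ V.T`.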