metadata
language: en
license: apache-2.0

Shears Model Card: shears-llama-7b-50-commonsense-heuristic

The heuristic subnetwork that Shears discovered in a super-network fine-tuned on commonsense reasoning datasets, built on LLaMA-7B with 50% sparsity.

Model Details

Information

Adapter Configuration

  • LoRA rank: 32
  • LoRA alpha: 64
  • LoRA target modules: q_proj, k_proj, v_proj, up_proj, gate_proj, down_proj
  • LoRA rank search space: [32, 24, 16] (for each LoRA module)
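To give a sense of scale, the adapter configuration can be turned into parameter counts. The sketch below assumes the standard LLaMA-7B shapes (32 decoder layers, hidden size 4096, MLP intermediate size 11008), the usual LoRA parameterization of r × (d_in + d_out) extra weights per adapted linear layer, and that every LoRA module picks its rank independently from the search space. The names are illustrative and not taken from the Shears code.

```python
# Worked arithmetic for the adapter configuration (illustrative sketch, not Shears code).
# Assumed LLaMA-7B shapes: 32 decoder layers, hidden 4096, MLP intermediate 11008.
# A LoRA adapter on a d_in -> d_out linear layer adds r * (d_in + d_out) weights
# (matrix A is r x d_in, matrix B is d_out x r).

HIDDEN = 4096
INTERMEDIATE = 11008
NUM_LAYERS = 32

# (d_in, d_out) of each LoRA target module in one decoder layer
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

RANK_CHOICES = [32, 24, 16]  # the per-module rank search space
LORA_ALPHA = 64


def lora_params(rank: int) -> int:
    """Total LoRA weights when every target module uses the same rank."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_layer * NUM_LAYERS


def search_space_size() -> int:
    """Number of rank assignments if each module chooses its rank independently."""
    num_modules = len(TARGET_SHAPES) * NUM_LAYERS
    return len(RANK_CHOICES) ** num_modules


if __name__ == "__main__":
    for r in RANK_CHOICES:
        print(f"rank {r}: {lora_params(r):,} LoRA parameters")
    # Standard PEFT scaling factor lora_alpha / r at the maximum rank
    print(f"alpha / r at rank 32: {LORA_ALPHA / 32}")
    print(f"rank configurations: {search_space_size():.3e}")
```

Under these assumptions the full rank-32 adapter holds 71,565,312 parameters (35,782,656 at rank 16), and the search space has 3^192 rank configurations, far too many to enumerate exhaustively.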

Training Hyperparameters

  • Batch size: 16
  • Learning rate: 3e-4
  • Epochs: 3
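As a quick sanity check on the schedule, the number of optimizer updates follows from the batch size and epoch count. This sketch assumes that batch size 16 is the effective batch (including any gradient accumulation) and that a final partial batch still counts as one step. The 170,000-example figure is only a round number suggested by the dataset's file name, not its exact size.

```python
import math


def optimizer_steps(num_examples: int, batch_size: int = 16, epochs: int = 3) -> int:
    """Optimizer updates for a full run; a final partial batch counts as one step."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs


# Hypothetical dataset size; substitute the real example count of commonsense_170k.json.
print(optimizer_steps(170_000))  # 31875
```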

Training Data

Unified commonsense reasoning dataset: commonsense_170k.json.

Evaluation Data

BoolQ, PIQA, SIQA, HellaSwag, WinoGrande, ARC-e, ARC-c, OBQA.

How to use

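Install our modified PEFT library by applying the Shears patch to PEFT v0.5.0 (the patch file peft-modifications-for-shears-inference-usage.patch must be reachable from inside the cloned peft directory, since git apply runs there):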

git clone https://github.com/huggingface/peft.git
pushd peft && git checkout v0.5.0 && git apply --ignore-space-change --ignore-whitespace peft-modifications-for-shears-inference-usage.patch && pip install -e . && popd

Then load the base model and adapter and run inference:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

def generate_prompt(instruction):
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 

                    ### Instruction:
                    {instruction}

                    ### Response:
                    """

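# The repository stores the sparse base model and the Shears adapter in separate subfolders.
base_model_path = "shears-llama-7b-50-commonsense-heuristic/base_model"
adapter_model_path = "shears-llama-7b-50-commonsense-heuristic/adapter_model"
base_model = AutoModelForCausalLM.from_pretrained(base_model_path)
# Attach the adapter using the patched PEFT installed above.
model = PeftModel.from_pretrained(base_model, adapter_model_path)
model.eval()

# Count non-zero weights to check the sparsity of the loaded model.
non_zero_params = sum((param.data != 0).sum().item() for _, param in model.named_parameters())
print(f"Number of all non-zero parameters: {non_zero_params}")

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# The LLaMA tokenizer defines no padding token; use id 0 for padding.
tokenizer.pad_token_id = 0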

# Adjacent string literals inside parentheses are concatenated into one instruction.
instruction = (
    "Please choose the correct answer to the question: A cactus stem is used to store\n\nAnswer1: fruit "
    "Answer2: liquid Answer3: food Answer4: spines\n\nAnswer format: answer1/answer2/answer3/answer4"
)
prompt = generate_prompt(instruction)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256,
        use_cache=True,
        num_beams=4,
    )
    # Decode the best beam; the decoded text still includes the prompt tokens.
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
print(output)
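The decoded output still contains the whole prompt, and the prompt itself lists answer1/answer2/answer3/answer4, so a scorer should only look at the text after "### Response:". A minimal post-processing sketch follows; the function names extract_answer and accuracy are hypothetical, and taking the first label found in the response is an assumption, not the official Shears evaluation code.

```python
import re
from typing import Optional, Sequence

DEFAULT_LABELS = ("answer1", "answer2", "answer3", "answer4")


def extract_answer(decoded: str, labels: Sequence[str] = DEFAULT_LABELS) -> Optional[str]:
    """Return the first answer label in the model's response, or None if absent.

    Labels are expected in lowercase; the response is lowercased before matching.
    """
    # Drop the echoed prompt: its "Answer format" line would otherwise always match.
    response = decoded.split("### Response:")[-1].lower()
    pattern = "|".join(re.escape(label) for label in labels)
    match = re.search(pattern, response)
    return match.group(0) if match else None


def accuracy(predictions: Sequence[Optional[str]], gold: Sequence[str]) -> float:
    """Percentage of predictions that exactly match the gold labels."""
    if not gold or len(predictions) != len(gold):
        raise ValueError("predictions and gold must be non-empty and equally long")
    correct = sum(p == g for p, g in zip(predictions, gold))
    return 100.0 * correct / len(gold)


if __name__ == "__main__":
    decoded = ("... Answer format: answer1/answer2/answer3/answer4\n\n"
               "### Response:\nthe correct answer is answer2</s>")
    print(extract_answer(decoded))  # answer2
    print(accuracy(["answer2", None], ["answer2", "answer3"]))  # 50.0
```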

Evaluation Results

Accuracy (%) on each benchmark; Average is the mean over the eight benchmarks.

Model            Sparsity  BoolQ  PIQA  SIQA  HellaSwag  WinoGrande  ARC-e  ARC-c  OBQA  Average
ChatGPT          -         73.1   85.4  68.5  78.5       66.1        89.8   79.9   74.8  77.0
LLaMA-7B-LoRA    -         68.9   80.7  77.4  78.1       78.8        77.8   61.3   74.8  74.7
LLaMA-7B-Shears  50%       67.3   79.1  77.5  73.3       77.7        74.4   57.9   72.8  72.5
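The Average column is the unweighted mean of the eight benchmark scores, which can be checked directly from the rows of the table (values copied as reported; the helper name macro_average is only illustrative):

```python
# Scores in table order: BoolQ, PIQA, SIQA, HellaSwag, WinoGrande, ARC-e, ARC-c, OBQA
SCORES = {
    "ChatGPT": [73.1, 85.4, 68.5, 78.5, 66.1, 89.8, 79.9, 74.8],
    "LLaMA-7B-LoRA": [68.9, 80.7, 77.4, 78.1, 78.8, 77.8, 61.3, 74.8],
    "LLaMA-7B-Shears": [67.3, 79.1, 77.5, 73.3, 77.7, 74.4, 57.9, 72.8],
}


def macro_average(scores):
    """Unweighted mean over benchmarks."""
    return sum(scores) / len(scores)


for model_name, scores in SCORES.items():
    print(f"{model_name}: {macro_average(scores):.1f}")
# ChatGPT: 77.0
# LLaMA-7B-LoRA: 74.7
# LLaMA-7B-Shears: 72.5
```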

Model Sources

Citation

@article{munoz2024shears,
  title={Shears: Unstructured Sparsity with Neural Low-rank Adapter Search},
  author={J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
  journal={},
  year={2024}
}

License

Apache-2.0