---
license: mit
datasets:
  - wikipedia
---

BitLinear-phi-1.5

BitLinear-phi-1.5 is a model trained using part of the method described in the paper "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits".

Note: our BitLinear layer applies 1-bit quantization only to the weights.

The other components described in the paper (RMSNorm, activation quantization) are omitted.

Rationale: the main contribution of the paper is a working binary weight quantization. We do not want to mix it with other components, because that would make the main part harder to evaluate.
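To make "weights only" concrete, here is a minimal sketch of such a layer, assuming PyTorch and the paper's absmean scheme (divide by the mean absolute weight, round to the nearest value in {-1, 0, +1}), trained with a straight-through estimator as is common for quantization-aware training. The class name `BitLinearSketch` and the `eps` value are hypothetical; this is not the BitLinear implementation from Bitlinear4HF, which may differ (for example, in the number of weight levels it uses).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinearSketch(nn.Linear):
    """Hypothetical weight-only BitLinear: ternary absmean weights, full-precision activations."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True, eps: float = 1e-5):
        super().__init__(in_features, out_features, bias=bias)
        self.eps = eps  # assumed small constant to avoid division by zero

    def quantize_weight(self) -> torch.Tensor:
        w = self.weight
        # Per-tensor scale: mean absolute value of the latent full-precision weights
        scale = w.abs().mean() + self.eps
        # Round w / scale to the nearest value in {-1, 0, +1}, then rescale
        w_q = (w / scale).round().clamp(-1, 1) * scale
        # Straight-through estimator: the forward pass sees w_q,
        # the backward pass treats quantization as the identity
        return w + (w_q - w).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the weights are quantized: no activation quantization, no extra RMSNorm
        return F.linear(x, self.quantize_weight(), self.bias)
```

Because it subclasses `nn.Linear`, the sketch can stand in wherever a linear layer is expected; the optimizer keeps updating the full-precision latent weights, while the forward pass only ever sees their quantized version.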

The model structure is from phi-1.5, with all linear layers except lm_head replaced with our custom BitLinear layer.
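Swapping every linear layer except `lm_head` amounts to walking the module tree and replacing `nn.Linear` children. The sketch below assumes PyTorch; `replace_linears`, `StandInBitLinear`, and `copy_weights` are hypothetical names, not the `replace_linear_in_hf` API from Bitlinear4HF, and the stand-in class is only a placeholder for a real BitLinear layer.

```python
import torch
import torch.nn as nn


class StandInBitLinear(nn.Linear):
    """Hypothetical placeholder: substitute a real weight-quantized BitLinear here."""


def replace_linears(module: nn.Module, new_cls=StandInBitLinear,
                    skip=("lm_head",), copy_weights: bool = True) -> nn.Module:
    """Recursively replace nn.Linear children with new_cls, skipping names in `skip`."""
    for name, child in module.named_children():
        if name in skip:
            continue  # leave the output projection as a standard Linear layer
        if isinstance(child, nn.Linear) and not isinstance(child, new_cls):
            new = new_cls(child.in_features, child.out_features, bias=child.bias is not None)
            new = new.to(device=child.weight.device, dtype=child.weight.dtype)
            if copy_weights:
                # Carry over the existing weight (and bias) tensors
                new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_linears(child, new_cls, skip, copy_weights)
    return module


# Usage on a toy model shaped like a causal LM (hidden layers + lm_head)
class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 4))
        self.lm_head = nn.Linear(4, 10)


toy = ToyLM()
w_before = toy.mlp[0].weight.detach().clone()
replace_linears(toy)
print(toy)
```

Skipping by child name works here because Hugging Face causal LMs typically expose the output projection as a top-level `lm_head` attribute; a model that names it differently would need a different `skip` tuple.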

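It was trained on a small subset of the Wikipedia dataset (the first 100,000 articles of the `20220301.en` dump), for research validation purposes only.

```python
from datasets import load_dataset

# English Wikipedia dump from 2022-03-01; keep only the first 100,000 articles
dataset = load_dataset("wikipedia", "20220301.en")
dataset = dataset['train'].select(range(int(1e5)))
```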

Please note that the kernel is not yet optimized for 1-bit matrices.
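An optimized low-bit kernel can exploit the fact that quantized weights do not need general multiplications. As an illustration of the idea only, assuming ternary weights in {-1, 0, +1} with one per-tensor scale as in the paper (the layer in this README may use a different set of levels), a matrix-vector product reduces to adding and subtracting input entries, followed by a single multiply by the scale. The function name `ternary_matvec` is hypothetical, and this NumPy loop is written for clarity, not speed; it is not the custom kernel from Bitlinear4HF.

```python
import numpy as np


def ternary_matvec(w_t: np.ndarray, scale: float, x: np.ndarray) -> np.ndarray:
    """Compute scale * (w_t @ x) for w_t with entries in {-1, 0, +1} using only adds/subtracts."""
    out = np.empty(w_t.shape[0], dtype=x.dtype)
    for i, row in enumerate(w_t):
        # +1 entries add the input, -1 entries subtract it, 0 entries are skipped
        out[i] = x[row == 1].sum() - x[row == -1].sum()
    # One multiplication per output element restores the weight scale
    return scale * out
```

A dense fp16 matmul, as used by the sample inference code, does not take advantage of this structure, which is why a dedicated kernel can be faster.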

The model was trained on a single RTX 3090 (24 GB) for 16 hours.

For faster (3x) inference, see https://github.com/Mrw33554432/Bitlinear4HF and install the custom kernel.

For the training code, see https://github.com/Mrw33554432/Bitlinear4HF.

The training code should be compatible with most LLMs on Hugging Face.

Training from the pretrained weights of a normal (full-precision) model will not work because of gradient explosion.

Sample inference code (slow)

```python
import torch
from replace_hf import replace_linear_in_hf  # replace_hf.py from the Bitlinear4HF repository
from transformers import AutoModelForCausalLM, AutoTokenizer


def quick_test(model, tokenizer, prompt: str):
    # Encode the prompt into token ids
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    # Generate a continuation (max_length counts prompt + generated tokens)
    outputs = model.generate(inputs, max_length=64)

    # Decode and print the generated text
    print(tokenizer.decode(outputs[0]))


# Create tensors on the GPU by default
torch.set_default_device("cuda")

# Original phi-1.5 tokenizer; model weights from this repository
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Mrw33554432/bitLinear-phi-1.5", trust_remote_code=True, torch_dtype=torch.float16)

print(model)  # architecture before replacement (standard Linear layers)
# Replace Linear layers with BitLinear
replace_linear_in_hf(model, keep_param=True)
print(model)  # architecture after replacement (BitLinear layers)

quick_test(model, tokenizer, prompt="Tom is the")
```