# bitnet_b1_58-3B_quantized / quantization.py
# Uploaded by kousw ("Upload 21 files"), commit 29964ce, 886 bytes.
import argparse
import torch
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
# This script only loads, quantizes and saves weights (no backward pass),
# so autograd tracking is switched off globally for the whole process.
torch.set_grad_enabled(False)

# Command-line interface: where to read the source checkpoint from and where
# to write the quantized result.
parser = argparse.ArgumentParser(
    description="Quantize a BitNet b1.58 checkpoint and save it to disk."
)
parser.add_argument(
    "--hf_path",
    default="1bitLLM/bitnet_b1_58-3B",
    type=str,
    help="Hugging Face Hub model id or local directory of the source "
    "checkpoint (default: %(default)s).",
)
parser.add_argument(
    "--output_path",
    default="./bitnet_b1_58-3B_quantized",
    type=str,
    help="Directory the quantized model is written to (default: %(default)s).",
)
def main(args):
    """Load a BitNet b1.58 checkpoint, quantize it, and save the result.

    Args:
        args: Parsed command-line namespace providing ``hf_path`` (Hub model
            id or local directory of the source checkpoint) and
            ``output_path`` (destination directory).

    Side effects:
        Writes the quantized model (sharded at ``max_shard_size="5GB"``) and
        the tokenizer files into ``args.output_path``, then prints a
        confirmation line to stdout.
    """
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        # Supported spelling of the deprecated `use_flash_attention_2=True`;
        # requests the same FlashAttention-2 attention backend.
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
    ).half()  # cast every floating-point parameter and buffer to fp16
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)

    # `quantize()` is defined in modeling_bitnet (not visible here); presumably
    # it rewrites the weights in place -- NOTE(review): confirm.
    model.quantize()
    model.save_pretrained(args.output_path, max_shard_size="5GB")
    # Save the tokenizer next to the weights so that `output_path` is a
    # self-contained checkpoint loadable with from_pretrained(output_path).
    tokenizer.save_pretrained(args.output_path)
    print("Quantized model saved to", args.output_path)
if __name__ == "__main__":
    # Script entry point: parse the command-line flags and run the pipeline.
    main(parser.parse_args())