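"""Quantize a BitNet b1.58 checkpoint and save the result.

Loads the model named by --hf_path (default: 1bitLLM/bitnet_b1_58-3B) in
fp16, applies the model's built-in quantize() step, and writes the
quantized weights to --output_path.
"""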
import argparse

import torch

from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer

# Inference only: no gradients are needed for quantization.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
parser.add_argument("--output_path", default="./bitnet_b1_58-3B_quantized", type=str)


def main(args):
    # Load the full-precision checkpoint in fp16.
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)

    # Quantize the weights in place, then save the model and tokenizer
    # together so the output directory is a self-contained checkpoint.
    model.quantize()
    model.save_pretrained(args.output_path, max_shard_size="5GB")
    tokenizer.save_pretrained(args.output_path)
    print("Quantized model saved to", args.output_path)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
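
# Example invocation (script filename assumed):
#   python quantize.py --hf_path 1bitLLM/bitnet_b1_58-3B \
#       --output_path ./bitnet_b1_58-3B_quantized
#
# A minimal sketch for loading the quantized checkpoint back, assuming
# BitnetForCausalLM.from_pretrained accepts the same arguments as above:
#
#   model = BitnetForCausalLM.from_pretrained(
#       "./bitnet_b1_58-3B_quantized",
#       device_map="auto",
#       torch_dtype=torch.float16,
#   )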