|
import argparse |
|
|
|
import torch |
|
|
|
from modeling_bitnet import BitnetForCausalLM |
|
from tokenization_bitnet import BitnetTokenizer |
|
|
|
# Inference/quantization only — autograd is never needed in this script.
torch.set_grad_enabled(False)


# Command-line interface: source checkpoint and destination directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_path",
    type=str,
    default="1bitLLM/bitnet_b1_58-3B",
)
parser.add_argument(
    "--output_path",
    type=str,
    default="./bitnet_b1_58-3B_quantized",
)
|
|
|
|
|
def main(args):
    """Load a BitNet checkpoint, quantize it, and save a self-contained copy.

    Args:
        args: Parsed CLI namespace with ``hf_path`` (HF hub id or local path
            of the source checkpoint) and ``output_path`` (destination
            directory for the quantized checkpoint).
    """
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        # NOTE(review): use_flash_attention_2 is deprecated in recent
        # transformers releases in favor of
        # attn_implementation="flash_attention_2"; kept as-is because the
        # local modeling_bitnet code may still expect the old kwarg.
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    )
    # torch_dtype=torch.float16 already yields an fp16 model, so the
    # previously chained redundant .half() call has been dropped.
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)

    # In-place weight quantization provided by the BitNet model class.
    model.quantize()

    model.save_pretrained(args.output_path, max_shard_size="5GB")
    # Also save the tokenizer so the output directory is a complete,
    # loadable checkpoint (previously it was loaded but never used).
    tokenizer.save_pretrained(args.output_path)

    print("Quantized model saved to", args.output_path)
|
|
|
|
|
if __name__ == "__main__": |
|
args = parser.parse_args() |
|
main(args) |
|
|