File size: 2,350 Bytes
29964ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import math
import random

import torch
from tqdm import tqdm

from eval_utils import get_test_dataset
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer

torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--hf_path", default="./", type=str) # 1bitLLM/bitnet_b1_58-3B
parser.add_argument("--seqlen", default=2048, type=int)
parser.add_argument("--max_dataset_size", default=100000, type=int)


def calulate_loss(model, input, loss_fct):
    output = model(
        input, use_cache=False, output_hidden_states=False, output_attentions=False
    )[0]
    shift_logits = output[:, :-1, :].contiguous()
    shift_labels = input[:, 1:]
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return loss


def main(args):
    datasets = ["wikitext2"]  # ['c4', 'wikitext2']

    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()

    ppl = []

    for dataset in datasets:
        testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
        acc_loss, count = 0.0, 0

        dataset_size = (
            args.max_dataset_size
            if len(testdata) > args.max_dataset_size
            else len(testdata)
        )
        progress = tqdm(range(dataset_size))
        for ii in progress:
            input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
            loss = calulate_loss(model, input, loss_fct)
            count += input.size(-1) - 1
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}")

        avg_loss = acc_loss / count / math.log(2)
        ppl.append(2**avg_loss)
        print("{} PPL: {}".format(dataset, ppl[-1]))

    print(ppl)
    print("Avg PPL:", sum(ppl) / len(ppl))


if __name__ == "__main__":
    torch.set_grad_enabled(False)
    args = parser.parse_args()
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    main(args)