open_lm
File size: 3,820 Bytes
e08f01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Script to evaluate a model on a validation set or generate text from a prompt.

Based on scripts/generate.py from the open_lm repo.
"""
import argparse
import json
import re

import torch

from open_lm.evaluate import evaluate_loop
from open_lm.data import get_data
from open_lm.model import create_model
from open_lm.distributed import init_distributed_device
from open_lm.params import parse_args

from scripts.generate_without_hf import Generator, GenerationArgs

# Maps a model-size key (the integer that main() parses from "params=<N>" in
# the checkpoint path; presumably the parameter count in millions -- TODO
# confirm) to the (hidden_dim, n_layers) pair used for that model.
_PARAMS_TO_WIDTH_DEPTH = {
    5: (96, 3),
    7: (128, 4),
    9: (160, 5),
    15: (224, 6),
    22: (288, 8),
    28: (320, 9),
    37: (384, 10),
    57: (480, 12),
    84: (576, 14),
    108: (640, 15),
    149: (704, 18),
    220: (832, 21),
    347: (1024, 23),
    455: (1120, 26),
    611: (1312, 26),
    901: (1504, 30),
}


def generate_model_jsonl(params):
    """Write the model config for a given model size to a JSON file.

    The file is created in the current working directory, is named
    ``layers=<depth>_hidden-dim=<width>.json`` and holds a single JSON object
    followed by a newline. An existing file with the same name is overwritten.

    Args:
        params: Model-size key; must be one of the keys of
            ``_PARAMS_TO_WIDTH_DEPTH``.

    Returns:
        The (relative) path of the config file that was written.

    Raises:
        KeyError: If ``params`` is not a supported size. The message lists
            the supported sizes. Nothing is written in that case.
    """
    try:
        width, depth = _PARAMS_TO_WIDTH_DEPTH[params]
    except KeyError:
        supported = ", ".join(str(size) for size in sorted(_PARAMS_TO_WIDTH_DEPTH))
        # Same exception type as a plain dict lookup, but with an actionable
        # message instead of just the offending key.
        raise KeyError(
            f"Unsupported model size {params!r}; supported sizes: {supported}"
        ) from None

    filepath = f"layers={depth}_hidden-dim={width}.json"
    data = {
        "hidden_dim": width,
        "n_layers": depth,
        "n_heads": 4,
        "seq_len": 2048,
        "vocab_size": 50432,
        "post_embed_norm": False,
        "weight_tying": False,
        "qk_norm": True,
    }

    # Explicit encoding so the output does not depend on the platform locale.
    with open(filepath, "w", encoding="utf-8") as file:
        file.write(json.dumps(data) + "\n")
    return filepath


class ModelArgs:
    """Argument namespace handed to the open_lm helpers used by this script.

    Every default produced by ``parse_args("")`` is copied onto the instance
    as an attribute, after which the handful of settings this script controls
    are overridden.

    Args:
        params: Model-size key forwarded to ``generate_model_jsonl``; the path
            that function returns is stored as ``self.model``.
        val_data: Validation data path; stored as a one-element list.
        val_data_key: Validation data key; stored as a one-element list.
    """

    def __init__(self, params, val_data, val_data_key):
        # Start from the library defaults.
        for name, value in vars(parse_args("")).items():
            setattr(self, name, value)

        # Fixed settings for this script.
        self.wandb = False
        self.seq_len = 2048
        self.vocab_size = 50432
        self.per_gpu_val_batch_size = 16

        # Caller-supplied validation data; stored wrapped in lists.
        self.val_data_key = [val_data_key]
        self.val_data = [val_data]

        # Writes the model config JSON to disk and keeps its path.
        self.model = generate_model_jsonl(params)


def main():
    """Evaluate a checkpoint on validation data, or generate text from a prompt.

    Command-line flags:
        --checkpoint: Path to the checkpoint file. The path must contain
            ``params=<N>``; ``N`` selects the model architecture.
        --val-data / --val-data-key: Validation data path and key. When
            ``--val-data`` is non-empty the model is evaluated and the metrics
            are printed; this takes precedence over ``--input-text``.
        --input-text: Prompt to generate from (used when ``--val-data`` is
            empty). The generated text is printed.
        --max-gen-len / --temperature / --top-p: Generation settings.

    If neither ``--val-data`` nor ``--input-text`` is given, a usage hint is
    printed and the function returns without loading anything. If the
    checkpoint path does not contain ``params=<N>``, the argument parser
    reports the error and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="path/to/checkpoint")

    parser.add_argument("--val-data", default="", help="Path to validation data. If empty, generate text.")
    parser.add_argument("--val-data-key", default="json.gz")

    parser.add_argument("--input-text", default="", type=str, help="Input text to generate from. If empty, evaluate on validation data.")
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)

    args = parser.parse_args()

    # Fail fast: there is nothing to do without a task, so do not pay for
    # loading the checkpoint and building the model first.
    if not args.val_data and not args.input_text:
        print("Please provide either --val-data or --input-text")
        return

    size_match = re.search(r"params=(\d+)", args.checkpoint)
    if size_match is None:
        # parser.error prints the usage plus this message and exits (status 2).
        parser.error(
            f"--checkpoint path must contain 'params=<N>' to select the model size; got {args.checkpoint!r}"
        )
    params = int(size_match.group(1))

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # map_location="cpu" keeps loading independent of the device the checkpoint
    # was saved from; the model is moved to the GPU below.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    # Strip only a leading "module." (presumably added by a DDP/DataParallel
    # wrapper -- confirm); a "module." substring elsewhere in a key is kept.
    wrapper_prefix = "module."
    state_dict = {
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
        for key, value in state_dict.items()
    }

    model_args = ModelArgs(params=params, val_data=args.val_data, val_data_key=args.val_data_key)
    device = init_distributed_device(model_args)
    model_args.device = device
    model = create_model(model_args)
    model.load_state_dict(state_dict)
    model.eval().cuda()

    if args.val_data:
        data = get_data(
            model_args,
            skip_train=True,
        )
        metrics = evaluate_loop(model, data["val_list"], 0, model_args, None)
        print(metrics)
    else:
        # --input-text is guaranteed non-empty here by the early check above.
        model = model.half()
        generator = Generator(model)
        input_text = [
            args.input_text,
        ]
        output = generator.generate(
            input_text,
            GenerationArgs(args.max_gen_len, args.temperature, args.top_p),
        )
        print("".join(output))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()