Me1234567890 committed on
Commit 4d4ea94 · 1 Parent(s): 4b32a8d

Create KAI-1B_Demo.py

Files changed (1)
  1. KAI-1B_Demo.py +143 -0
KAI-1B_Demo.py ADDED
@@ -0,0 +1,143 @@
from mistral.cache import RotatingBufferCache
import torch
import fire
from typing import List, Optional
from pathlib import Path

from mistral.model import Transformer
from mistral.tokenizer import Tokenizer


def sample_top_p(probs: torch.Tensor, p: float):
    assert 0 <= p <= 1

    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p  # mask out tokens beyond the top-p nucleus
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize the remaining mass
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)


def sample(logits: torch.Tensor, temperature: float, top_p: float):
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = sample_top_p(probs, top_p)
    else:
        next_token = torch.argmax(logits, dim=-1).unsqueeze(0)  # greedy decoding when temperature is 0

    return next_token.reshape(-1)


@torch.inference_mode()
def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, chunk_size: Optional[int] = None, temperature: float = 0.7):
    model = model.eval()
    B, V = len(prompts), model.args.vocab_size

    # Tokenize
    encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts]
    seqlens = [len(x) for x in encoded_prompts]

    # Cache
    cache_window = min(model.args.sliding_window, max(seqlens) + max_tokens)  # never larger than the sliding window
    cache = RotatingBufferCache(model.args.n_layers, model.args.max_batch_size, cache_window, model.args.n_kv_heads, model.args.head_dim)
    cache.to(device=model.device, dtype=model.dtype)
    cache.reset()

    # Bookkeeping
    logprobs = [[] for _ in range(B)]
    last_token_prelogits = None

    # One chunk if size not specified
    max_prompt_len = max(seqlens)
    if chunk_size is None:
        chunk_size = max_prompt_len

    # Encode prompt by chunks
    for s in range(0, max_prompt_len, chunk_size):
        prompt_chunks = [p[s:s+chunk_size] for p in encoded_prompts]
        assert all(len(p) > 0 for p in prompt_chunks)
        prelogits = model.forward(
            torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
            seqlens=[len(p) for p in prompt_chunks],
            cache=cache
        )
        logits = torch.log_softmax(prelogits, dim=-1)

        if last_token_prelogits is not None:
            # Pass > 1: score this chunk's first token with the previous chunk's last logits
            last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
            for i_seq in range(B):
                logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item())

        offset = 0
        for i_seq, sequence in enumerate(prompt_chunks):
            logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
            offset += len(sequence)

        last_token_prelogits = prelogits.index_select(0, torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1)
        assert last_token_prelogits.shape == (B, V)

    # Decode max_tokens new tokens autoregressively
    generated_tokens = []
    for i_token in range(max_tokens):
        next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8)

        last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
        for i in range(B):
            logprobs[i].append(last_token_logits[i, next_token[i]].item())

        generated_tokens.append(next_token[:, None])
        last_token_prelogits = model.forward(next_token, seqlens=[1] * len(prompts), cache=cache)
        assert last_token_prelogits.shape == (B, V)

    generated_words = []
    if generated_tokens:
        generated_tokens = torch.cat(generated_tokens, 1)
        for i, x in enumerate(encoded_prompts):
            generated_words.append(tokenizer.decode(x + generated_tokens[i].tolist()))

    return generated_words, logprobs


def interactive(model_path: str, max_tokens: int = 35, temperature: float = 0.7):
    tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
    transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)

    while True:
        prompt = input("Prompt: ")
        res, _logprobs = generate(
            [prompt],
            transformer,
            tokenizer,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        print(res[0])
        print("=====================")

def demo(model_path: str, max_tokens: int = 35, temperature: float = 0):
    tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
    transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)

    res, _logprobs = generate(
        [
            "This is a test",
            "This is another test",
            "This is a third test, KAI is very good at testing. ",
        ],
        transformer,
        tokenizer,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    for x in res:
        print(x)
        print("=====================")

if __name__ == "__main__":
    fire.Fire({
        "interactive": interactive,
        "demo": demo,
    })
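
For reference, the script's two entry points are exposed through python-fire, so once the KAI-1B weights and tokenizer.model sit in a local folder (the path below is a placeholder), it can be invoked roughly like this:

    python KAI-1B_Demo.py demo /path/to/KAI-1B --max_tokens=64
    python KAI-1B_Demo.py interactive /path/to/KAI-1B --temperature=0.7

The demo command completes the three built-in test prompts in a batch; interactive reads prompts from stdin in a loop and prints one completion at a time.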