PlanetDOGE committed
Commit 160bd5c · 1 Parent(s): be55ff8

Delete KAI-1B_Demo.py

Files changed (1)
  KAI-1B_Demo.py  +0 -143
KAI-1B_Demo.py DELETED
@@ -1,143 +0,0 @@
- from mistral.cache import RotatingBufferCache
- import torch
- import fire
- from typing import List
- from pathlib import Path
-
- from mistral.model import Transformer
- from mistral.tokenizer import Tokenizer
-
-
- def sample_top_p(probs: torch.Tensor, p: float):
-     assert 0 <= p <= 1
-
-     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-     probs_sum = torch.cumsum(probs_sort, dim=-1)
-     mask = probs_sum - probs_sort > p
-     probs_sort[mask] = 0.0
-     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-     next_token = torch.multinomial(probs_sort, num_samples=1)
-     return torch.gather(probs_idx, -1, next_token)
-
-
- def sample(logits: torch.Tensor, temperature: float, top_p: float):
-     if temperature > 0:
-         probs = torch.softmax(logits / temperature, dim=-1)
-         next_token = sample_top_p(probs, top_p)
-     else:
-         next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
-
-     return next_token.reshape(-1)
-
-
- @torch.inference_mode()
- def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, chunk_size: int = None, temperature: float = 0.7):
-     model = model.eval()
-     B, V = len(prompts), model.args.vocab_size
-
-     # Tokenize
-     encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts]
-     seqlens = [len(x) for x in encoded_prompts]
-
-     # Cache
-     cache_window = min(model.args.sliding_window, max(seqlens) + max_tokens)
-     cache = RotatingBufferCache(model.args.n_layers, model.args.max_batch_size, cache_window, model.args.n_kv_heads, model.args.head_dim)
-     cache.to(device=model.device, dtype=model.dtype)
-     cache.reset()
-
-     # Bookkeeping
-     logprobs = [[] for _ in range(B)]
-     last_token_prelogits = None
-
-     # One chunk if size not specified
-     max_prompt_len = max(seqlens)
-     if chunk_size is None:
-         chunk_size = max_prompt_len
-
-     # Encode prompt by chunks
-     for s in range(0, max_prompt_len, chunk_size):
-         prompt_chunks = [p[s:s+chunk_size] for p in encoded_prompts]
-         assert all(len(p) > 0 for p in prompt_chunks)
-         prelogits = model.forward(
-             torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
-             seqlens=[len(p) for p in prompt_chunks],
-             cache=cache
-         )
-         logits = torch.log_softmax(prelogits, dim=-1)
-
-         if last_token_prelogits is not None:
-             # Pass > 1
-             last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
-             for i_seq in range(B):
-                 logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item())
-
-         offset = 0
-         for i_seq, sequence in enumerate(prompt_chunks):
-             logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
-             offset += len(sequence)
-
-         last_token_prelogits = prelogits.index_select(0, torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1)
-         assert last_token_prelogits.shape == (B, V)
-
-     # decode
-     generated_tokens = []
-     for i_token in range(max_tokens):
-         next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8)
-
-         last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
-         for i in range(B):
-             logprobs[i].append(last_token_logits[i, next_token[i]].item())
-
-         generated_tokens.append(next_token[:, None])
-         last_token_prelogits = model.forward(next_token, seqlens=[1] * len(prompts), cache=cache)
-         assert last_token_prelogits.shape == (B, V)
-
-     generated_words = []
-     if generated_tokens:
-         generated_tokens = torch.cat(generated_tokens, 1)
-         for i, x in enumerate(encoded_prompts):
-             generated_words.append(tokenizer.decode(x + generated_tokens[i].tolist()))
-
-     return generated_words, logprobs
-
-
- def interactive(model_path: str, max_tokens: int = 35, temperature: float = 0.7):
-     tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
-     transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)
-
-     while True:
-         prompt = input("Prompt: ")
-         res, _logprobs = generate(
-             [prompt],
-             transformer,
-             tokenizer,
-             max_tokens=max_tokens,
-             temperature=temperature,
-         )
-         print(res[0])
-         print("=====================")
-
- def demo(model_path: str, max_tokens: int = 35, temperature: float = 0):
-     tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
-     transformer = Transformer.from_folder(Path(model_path), max_batch_size=3)
-
-     res, _logprobs = generate(
-         [
-             "This is a test",
-             "This is another test",
-             "This is a third test, KAI is very good at testing. ",
-         ],
-         transformer,
-         tokenizer,
-         max_tokens=max_tokens,
-         temperature=temperature,
-     )
-     for x in res:
-         print(x)
-         print("=====================")
-
- if __name__ == "__main__":
-     fire.Fire({
-         "interactive": interactive,
-         "demo": demo,
-     })
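For reference, the deleted script registered its two entry points through fire.Fire, so before this commit it could be run from the command line roughly as follows. The checkpoint path is a placeholder, the flag values are illustrative overrides of the function defaults, and the folder is assumed to contain the tokenizer.model and weights that Transformer.from_folder expects:

    python KAI-1B_Demo.py demo /path/to/KAI-1B
    python KAI-1B_Demo.py interactive /path/to/KAI-1B --max_tokens=64 --temperature=0.7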