kcarnold committed
Commit 9b8968e · Parent: 898f051

sync up with the backend

Files changed (5)
  1. custom_llm.py +43 -5
  2. custom_llm_inference.py +70 -1
  3. pyproject.toml +15 -1
  4. test_llm_inference.py +65 -0
  5. uv.lock +0 -0
custom_llm.py CHANGED
@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import Dict, List, Optional
 
+from pydantic import BaseModel
 import torch
 import uvicorn
 from fastapi import FastAPI, HTTPException
@@ -12,7 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.testclient import TestClient
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from custom_llm_inference import get_highlights_inner, get_next_token_predictions_inner
+from custom_llm_inference import get_highlights_inner, get_next_token_predictions_inner, continue_messages_inner
 
 ml_models = {}
 
@@ -36,7 +37,12 @@ async def models_lifespan(app: FastAPI):
 
     ml_models["llm"] = llm = {
         'tokenizer': AutoTokenizer.from_pretrained(model_name),
-        'model': AutoModelForCausalLM.from_pretrained(model_name, device_map="auto" if USE_GPU else "cpu", torch_dtype=dtype)
+        'model': AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map="auto" if USE_GPU else "cpu",
+            torch_dtype=dtype,
+            attn_implementation='eager'
+        )
     }
     print("Loaded llm with device map:")
     print(llm['model'].hf_device_map)
@@ -61,7 +67,7 @@ async def models_lifespan(app: FastAPI):
 
     start = time.time()
     response = client.get("/api/gen_revisions",
-        params={"doc": test_doc, "prompt": test_prompt, "n": 1})
+        params={"doc": test_doc, "prompt": test_prompt, "n": 1, "max_length": 16})
     print(f"Gen revisions endpoint: {time.time() - start:.2f}s")
 
     yield
@@ -132,7 +138,9 @@ def get_next_token_predictions(original_doc: str,
 def gen_revisions(
         prompt: str,
         doc: str,
-        n: Optional[int] = 5):
+        n: Optional[int] = 5,
+        max_length: Optional[int] = 1024,
+        ):
 
 
     model = ml_models['llm']['model']
@@ -148,7 +156,7 @@ def gen_revisions(
 
     generations = model.generate(
         tokenized_chat, num_return_sequences=n,
-        max_length=1024, do_sample=True, top_k=50, top_p=0.95, temperature=0.5,
+        max_new_tokens=max_length, do_sample=True, top_k=50, top_p=0.95, temperature=0.5,
         return_dict_in_generate=True, output_scores=True)
     generated_docs = tokenizer.batch_decode(generations.sequences, skip_special_tokens=True)
     #print(generations.scores)
@@ -166,5 +174,35 @@ def gen_revisions(
     }
 
 
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ContinueMessagesRequest(BaseModel):
+    messages: List[Message]
+    n_branch_tokens: int = 5
+    n_future_tokens: int = 5
+
+
+@app.post('/api/continue_messages')
+def continue_messages(request: ContinueMessagesRequest):
+
+    messages = [{"role": m.role, "content": m.content} for m in request.messages]
+    if len(messages) == 0:
+        raise HTTPException(status_code=400, detail="At least one message must be provided.")
+    n_branch_tokens = request.n_branch_tokens
+    n_future_tokens = request.n_future_tokens
+
+    model = ml_models['llm']['model']
+    tokenizer = ml_models['llm']['tokenizer']
+
+    generated_docs = continue_messages_inner(model, tokenizer, messages, n_branch_tokens, n_future_tokens)
+
+    return {
+        'continuations': [dict(doc_text=doc) for doc in generated_docs]
+    }
+
+
+
 if __name__ == "__main__":
     uvicorn.run(app, host="localhost", port=PORT)
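
For reference, a minimal sketch (not part of this commit) of how a client could exercise the new /api/continue_messages endpoint once the server is running. The port value is an assumption and should match whatever PORT custom_llm.py passes to uvicorn.run; the payload fields mirror the ContinueMessagesRequest model added above.

import requests

# Minimal sketch, not part of this commit: POST to the new endpoint.
# PORT is an assumption; use the PORT configured in custom_llm.py.
PORT = 8000

resp = requests.post(
    f"http://localhost:{PORT}/api/continue_messages",
    json={
        "messages": [
            {"role": "user", "content": "Rewrite this document to make more sense."},
            {"role": "assistant", "content": "Sure, here's the document"},
        ],
        "n_branch_tokens": 5,
        "n_future_tokens": 5,
    },
)
resp.raise_for_status()
# The endpoint returns {"continuations": [{"doc_text": ...}, ...]}, one entry per branch.
for continuation in resp.json()["continuations"]:
    print(continuation["doc_text"])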
custom_llm_inference.py CHANGED
@@ -37,7 +37,8 @@ def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
     updated_doc_ids = tokenize_doc_in_progress(tokenizer, updated_doc)
 
     joined_ids = torch.cat([tokenized_chat, updated_doc_ids])
-    # Call the model
+
+    # Compute the next-token logits for the entire document
     with torch.no_grad():
         logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
 
@@ -191,3 +192,71 @@ def get_next_token_predictions_slow(
 
     decoded_next_tokens = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
     return decoded_next_tokens, next_token_logits
+
+
+
+def continue_messages_inner(model, tokenizer, messages, n_branch_tokens, n_future_tokens):
+    device = model.device
+
+    final_message_is_assistant = messages[-1]['role'] == "assistant"
+    print(f"final_message_is_assistant: {final_message_is_assistant}")
+    # if final_message_is_assistant:
+    #     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, continue_final_message=True, return_tensors="pt").to(model.device)
+    # else:
+    #     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
+    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", continue_final_message=True).to(model.device)
+
+    print(tokenizer.batch_decode(tokenized_chat, skip_special_tokens=False))
+
+    # This fails with
+    # RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
+    # generations = model.generate(
+    #     tokenized_chat,
+    #     num_return_sequences=n_branch_tokens,
+    #     num_beam_groups=n_branch_tokens, num_beams=n_branch_tokens,
+    #     do_sample=False, max_new_tokens=n_future_tokens, diversity_penalty=1e5, top_k=None,
+    #     return_dict_in_generate=True, output_scores=True)
+
+    # Instead, we'll do this in two steps:
+    # 1. Get the next token predictions for the k most likely continuations
+    from transformers.cache_utils import DynamicCache
+    past_key_values = DynamicCache()
+    with torch.no_grad():
+        model_outs = model(
+            tokenized_chat,
+            past_key_values=past_key_values,
+            output_hidden_states=True,
+            use_cache=True,
+        )
+    branch_tokens = model_outs.logits[0, -1].topk(n_branch_tokens).indices
+
+    hypotheses = branch_tokens.unsqueeze(1)
+    # Branch off the k most likely continuations
+    past_key_values.reorder_cache(torch.zeros((n_branch_tokens,), dtype=torch.long, device=device))
+
+    # 2. Generate the next n_future_tokens for each branch
+    for i in range(n_future_tokens):
+        position_id_for_final_token = tokenized_chat.shape[0] + i
+        cache_position = torch.full((1,), position_id_for_final_token, dtype=int, device=device)
+        final_token_ids = hypotheses[:, -1:]
+        with torch.no_grad():
+            model_outs = model(
+                final_token_ids,
+                past_key_values=past_key_values,
+                output_hidden_states=True,
+                use_cache=True,
+                cache_position=cache_position
+            )
+
+        # Grab the single most likely token from each of the k sequences
+        next_token_logits = model_outs.logits[:, -1]
+        vocab_size = model.config.vocab_size
+        assert next_token_logits.shape == (n_branch_tokens, vocab_size), f"{next_token_logits.shape=}, {n_branch_tokens=}, {vocab_size=}"
+        most_likely_token_ids = next_token_logits.argmax(dim=-1)
+        hypotheses = torch.cat([
+            hypotheses,
+            most_likely_token_ids.unsqueeze(1)
+        ], dim=1)
+
+    generated_docs = tokenizer.batch_decode(hypotheses, skip_special_tokens=True)
+    return generated_docs
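
For reference, a minimal sketch (not part of this commit) of calling continue_messages_inner directly. The model name and CPU placement are assumptions that mirror the fixture in test_llm_inference.py below; the final message is an assistant turn because the function applies the chat template with continue_final_message=True.

from transformers import AutoModelForCausalLM, AutoTokenizer
from custom_llm_inference import continue_messages_inner

# Assumptions (mirroring the test fixture): gemma-2-2b-it, loaded on CPU.
model_name = 'google/gemma-2-2b-it'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")

messages = [
    {"role": "user", "content": "Rewrite this document to make more sense."},
    {"role": "assistant", "content": "Sure, here's the document"},
]

# Each returned string is one branch: a distinct high-probability next token
# followed by n_future_tokens greedy continuations.
docs = continue_messages_inner(model, tokenizer, messages,
                               n_branch_tokens=5, n_future_tokens=5)
for doc in docs:
    print(repr(doc))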
pyproject.toml CHANGED
@@ -3,9 +3,23 @@ name = "writing-prototypes"
 version = "0.1.0"
 description = "Add your description here"
 readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.11,<3.13"
 dependencies = [
+    "fastapi>=0.115.8",
     "pandas>=2.2.3",
+    "pydantic>=2.10.6",
     "requests>=2.32.3",
     "streamlit==1.40.1",
 ]
+
+[dependency-groups]
+gpu = [
+    "accelerate>=1.1.1",
+    "torch>=2.5.1",
+    "transformers>=4.46.2",
+    "tokenizers>=0.21.0",
+]
+dev = [
+    "ipython>=8.32.0",
+    "marimo>=0.10.6",
+]
test_llm_inference.py ADDED
@@ -0,0 +1,65 @@
+import pytest
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import custom_llm_inference
+from transformers.cache_utils import DynamicCache
+
+@pytest.fixture
+def model_and_tokenizer():
+    model_name = 'google/gemma-2-2b-it'
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.bos_token_id is None:
+        tokenizer.bos_token_id = tokenizer.pad_token_id
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        device_map="cpu",
+        #torch_dtype=torch.float16
+    )
+    return model, tokenizer
+
+@pytest.fixture
+def sample_inputs():
+    doc = "The quick brown fox loves to jump over lazy dogs."
+    prompt = "Rewrite this document to make more sense."
+    doc_in_progress = "Sure, here's the document rewritten as requested:\n\nA fox,"
+    return doc, prompt, doc_in_progress
+
+def test_get_next_token_predictions(model_and_tokenizer, sample_inputs):
+    model, tokenizer = model_and_tokenizer
+    doc, prompt, doc_in_progress = sample_inputs
+
+    predictions = custom_llm_inference.get_next_token_predictions_slow(
+        model, tokenizer, doc, prompt, doc_in_progress=doc_in_progress, k=5
+    )
+
+    assert len(predictions) == 2  # Should return (token_texts, logits)
+    assert len(predictions[0]) == 5  # Should return k=5 predictions
+    assert predictions[1].shape[1] == model.config.vocab_size
+
+def test_get_tokenized_chat(model_and_tokenizer, sample_inputs):
+    model, tokenizer = model_and_tokenizer
+    doc, prompt, _ = sample_inputs
+
+    tokenized_chat = custom_llm_inference.get_tokenized_chat(tokenizer, prompt, doc)
+
+    assert isinstance(tokenized_chat, torch.Tensor)
+    assert tokenized_chat.dim() == 1
+    assert tokenized_chat.dtype == torch.int64
+
+def test_highlights(model_and_tokenizer, sample_inputs):
+    model, tokenizer = model_and_tokenizer
+    doc, prompt, updated_doc = sample_inputs
+
+    highlights = custom_llm_inference.get_highlights_inner(
+        model, tokenizer, doc, prompt, updated_doc=updated_doc, k=5
+    )
+
+    assert isinstance(highlights, list)
+    assert len(highlights) > 0
+    for h in highlights:
+        assert h['start'] >= 0
+        assert h['end'] >= h['start']
+        assert isinstance(h['token'], str)
+        assert isinstance(h['token_loss'], float)
+        assert isinstance(h['most_likely_token'], str)
+        assert isinstance(h['topk_tokens'], list)
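
A note on the new tests: the model_and_tokenizer fixture loads the full google/gemma-2-2b-it checkpoint on CPU, so the first run downloads the weights and the suite is slow. pytest does not appear in the dependencies shown in this diff, so the sketch below assumes it is available in the active environment.

import pytest

# Minimal sketch, not part of this commit: run only this module, stopping at
# the first failure. Assumes pytest is installed in the active environment.
raise SystemExit(pytest.main(["-x", "test_llm_inference.py"]))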
uv.lock ADDED
The diff for this file is too large to render.