kcarnold committed on
Commit
329490e
·
1 Parent(s): 9b8968e

Show a problem with the current approach.

Browse files
Files changed (3) hide show
  1. pyproject.toml +1 -0
  2. test_llm_inference.py +64 -1
  3. uv.lock +35 -0
pyproject.toml CHANGED
@@ -8,6 +8,7 @@ dependencies = [
8
  "fastapi>=0.115.8",
9
  "pandas>=2.2.3",
10
  "pydantic>=2.10.6",
 
11
  "requests>=2.32.3",
12
  "streamlit==1.40.1",
13
  ]
 
8
  "fastapi>=0.115.8",
9
  "pandas>=2.2.3",
10
  "pydantic>=2.10.6",
11
+ "pytest>=8.3.4",
12
  "requests>=2.32.3",
13
  "streamlit==1.40.1",
14
  ]
test_llm_inference.py CHANGED
@@ -13,7 +13,7 @@ def model_and_tokenizer():
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
  device_map="cpu",
16
- #torch_dtype=torch.float16
17
  )
18
  return model, tokenizer
19
 
@@ -63,3 +63,66 @@ def test_highlights(model_and_tokenizer, sample_inputs):
63
  assert isinstance(h['token_loss'], float)
64
  assert isinstance(h['most_likely_token'], str)
65
  assert isinstance(h['topk_tokens'], list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
  device_map="cpu",
16
+ torch_dtype=torch.float16
17
  )
18
  return model, tokenizer
19
 
 
63
  assert isinstance(h['token_loss'], float)
64
  assert isinstance(h['most_likely_token'], str)
65
  assert isinstance(h['topk_tokens'], list)
66
+
67
+ def compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress, k=5):
68
+ """
69
+ Extracts and compares the next token predictions between the fast method and slow method.
70
+ Returns the differences between the two approaches for analysis.
71
+ """
72
+ # Get predictions from the fast method (using cache)
73
+ fast_tokens, fast_logits = custom_llm_inference.get_next_token_predictions_inner(
74
+ model, tokenizer, doc, prompt, doc_in_progress, k
75
+ )
76
+
77
+ # Get predictions from the slow method (recomputing for each token)
78
+ slow_tokens, slow_logits = custom_llm_inference.get_next_token_predictions_slow(
79
+ model, tokenizer, doc, prompt, doc_in_progress, k
80
+ )
81
+
82
+ # Compare the decoded tokens (this is what users will see)
83
+ token_matches = [fast == slow for fast, slow in zip(fast_tokens, slow_tokens)]
84
+
85
+ # Calculate the difference in logits for most likely next tokens
86
+ fast_most_likely = fast_logits.argmax(dim=-1)
87
+ slow_most_likely = slow_logits.argmax(dim=-1)
88
+ logit_match = torch.eq(fast_most_likely, slow_most_likely).cpu().numpy()
89
+
90
+ # Calculate numerical difference in logits
91
+ logit_diff_norm = torch.linalg.vector_norm((fast_logits - slow_logits).to(torch.float32), dim=1).cpu().numpy()
92
+
93
+ return {
94
+ "fast_tokens": fast_tokens,
95
+ "slow_tokens": slow_tokens,
96
+ "token_matches": token_matches,
97
+ "token_match_all": all(token_matches),
98
+ "logit_match": logit_match,
99
+ "logit_diff_norm": logit_diff_norm
100
+ }
101
+
102
+ def test_lookahead_token_consistency(model_and_tokenizer, sample_inputs):
103
+ """
104
+ Test that demonstrates the potential issue with cache position indices
105
+ when generating lookahead tokens.
106
+ """
107
+ model, tokenizer = model_and_tokenizer
108
+ doc, prompt, doc_in_progress = sample_inputs
109
+
110
+ results = compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress)
111
+
112
+ # Check if the tokens are the same
113
+ assert results["token_match_all"], (
114
+ f"Fast and slow methods produced different tokens.\n"
115
+ f"Fast: {results['fast_tokens']}\n"
116
+ f"Slow: {results['slow_tokens']}"
117
+ )
118
+
119
+ # Check if the most likely next tokens based on logits are the same
120
+ assert all(results["logit_match"]), (
121
+ f"Fast and slow methods predicted different most likely next tokens"
122
+ )
123
+
124
+ # Check that the logit differences are minimal
125
+ # This might fail if there's a bug in the cache position indices
126
+ assert all(diff < 1e-4 for diff in results["logit_diff_norm"]), (
127
+ f"Significant difference in logits between fast and slow methods: {results['logit_diff_norm']}"
128
+ )
uv.lock CHANGED
@@ -287,6 +287,15 @@ wheels = [
287
  { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
288
  ]
289
 
 
 
 
 
 
 
 
 
 
290
  [[package]]
291
  name = "ipython"
292
  version = "8.32.0"
@@ -719,6 +728,15 @@ wheels = [
719
  { url = "https://files.pythonhosted.org/packages/0b/30/2b61876e2722374558b871dfbfcbe4e406626d63f4f6ed92e9c8e24cac37/pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25", size = 2254890 },
720
  ]
721
 
 
 
 
 
 
 
 
 
 
722
  [[package]]
723
  name = "prompt-toolkit"
724
  version = "3.0.50"
@@ -917,6 +935,21 @@ wheels = [
917
  { url = "https://files.pythonhosted.org/packages/eb/f5/b9e2a42aa8f9e34d52d66de87941ecd236570c7ed2e87775ed23bbe4e224/pymdown_extensions-10.14.3-py3-none-any.whl", hash = "sha256:05e0bee73d64b9c71a4ae17c72abc2f700e8bc8403755a00580b49a4e9f189e9", size = 264467 },
918
  ]
919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920
  [[package]]
921
  name = "python-dateutil"
922
  version = "2.9.0.post0"
@@ -1505,6 +1538,7 @@ dependencies = [
1505
  { name = "fastapi" },
1506
  { name = "pandas" },
1507
  { name = "pydantic" },
 
1508
  { name = "requests" },
1509
  { name = "streamlit" },
1510
  ]
@@ -1526,6 +1560,7 @@ requires-dist = [
1526
  { name = "fastapi", specifier = ">=0.115.8" },
1527
  { name = "pandas", specifier = ">=2.2.3" },
1528
  { name = "pydantic", specifier = ">=2.10.6" },
 
1529
  { name = "requests", specifier = ">=2.32.3" },
1530
  { name = "streamlit", specifier = "==1.40.1" },
1531
  ]
 
287
  { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
288
  ]
289
 
290
+ [[package]]
291
+ name = "iniconfig"
292
+ version = "2.0.0"
293
+ source = { registry = "https://pypi.org/simple" }
294
+ sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
295
+ wheels = [
296
+ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
297
+ ]
298
+
299
  [[package]]
300
  name = "ipython"
301
  version = "8.32.0"
 
728
  { url = "https://files.pythonhosted.org/packages/0b/30/2b61876e2722374558b871dfbfcbe4e406626d63f4f6ed92e9c8e24cac37/pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25", size = 2254890 },
729
  ]
730
 
731
+ [[package]]
732
+ name = "pluggy"
733
+ version = "1.5.0"
734
+ source = { registry = "https://pypi.org/simple" }
735
+ sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
736
+ wheels = [
737
+ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
738
+ ]
739
+
740
  [[package]]
741
  name = "prompt-toolkit"
742
  version = "3.0.50"
 
935
  { url = "https://files.pythonhosted.org/packages/eb/f5/b9e2a42aa8f9e34d52d66de87941ecd236570c7ed2e87775ed23bbe4e224/pymdown_extensions-10.14.3-py3-none-any.whl", hash = "sha256:05e0bee73d64b9c71a4ae17c72abc2f700e8bc8403755a00580b49a4e9f189e9", size = 264467 },
936
  ]
937
 
938
+ [[package]]
939
+ name = "pytest"
940
+ version = "8.3.4"
941
+ source = { registry = "https://pypi.org/simple" }
942
+ dependencies = [
943
+ { name = "colorama", marker = "sys_platform == 'win32'" },
944
+ { name = "iniconfig" },
945
+ { name = "packaging" },
946
+ { name = "pluggy" },
947
+ ]
948
+ sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
949
+ wheels = [
950
+ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
951
+ ]
952
+
953
  [[package]]
954
  name = "python-dateutil"
955
  version = "2.9.0.post0"
 
1538
  { name = "fastapi" },
1539
  { name = "pandas" },
1540
  { name = "pydantic" },
1541
+ { name = "pytest" },
1542
  { name = "requests" },
1543
  { name = "streamlit" },
1544
  ]
 
1560
  { name = "fastapi", specifier = ">=0.115.8" },
1561
  { name = "pandas", specifier = ">=2.2.3" },
1562
  { name = "pydantic", specifier = ">=2.10.6" },
1563
+ { name = "pytest", specifier = ">=8.3.4" },
1564
  { name = "requests", specifier = ">=2.32.3" },
1565
  { name = "streamlit", specifier = "==1.40.1" },
1566
  ]