Show a problem with the current approach.
- pyproject.toml +1 -0
- test_llm_inference.py +64 -1
- uv.lock +35 -0
pyproject.toml
CHANGED
@@ -8,6 +8,7 @@ dependencies = [
     "fastapi>=0.115.8",
     "pandas>=2.2.3",
     "pydantic>=2.10.6",
+    "pytest>=8.3.4",
     "requests>=2.32.3",
     "streamlit==1.40.1",
 ]
test_llm_inference.py
CHANGED
@@ -13,7 +13,7 @@ def model_and_tokenizer():
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         device_map="cpu",
-
+        torch_dtype=torch.float16
     )
     return model, tokenizer
 
@@ -63,3 +63,66 @@ def test_highlights(model_and_tokenizer, sample_inputs):
         assert isinstance(h['token_loss'], float)
         assert isinstance(h['most_likely_token'], str)
         assert isinstance(h['topk_tokens'], list)
+
+def compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress, k=5):
+    """
+    Extracts and compares the next token predictions between the fast method and slow method.
+    Returns the differences between the two approaches for analysis.
+    """
+    # Get predictions from the fast method (using cache)
+    fast_tokens, fast_logits = custom_llm_inference.get_next_token_predictions_inner(
+        model, tokenizer, doc, prompt, doc_in_progress, k
+    )
+
+    # Get predictions from the slow method (recomputing for each token)
+    slow_tokens, slow_logits = custom_llm_inference.get_next_token_predictions_slow(
+        model, tokenizer, doc, prompt, doc_in_progress, k
+    )
+
+    # Compare the decoded tokens (this is what users will see)
+    token_matches = [fast == slow for fast, slow in zip(fast_tokens, slow_tokens)]
+
+    # Calculate the difference in logits for most likely next tokens
+    fast_most_likely = fast_logits.argmax(dim=-1)
+    slow_most_likely = slow_logits.argmax(dim=-1)
+    logit_match = torch.eq(fast_most_likely, slow_most_likely).cpu().numpy()
+
+    # Calculate numerical difference in logits
+    logit_diff_norm = torch.linalg.vector_norm((fast_logits - slow_logits).to(torch.float32), dim=1).cpu().numpy()
+
+    return {
+        "fast_tokens": fast_tokens,
+        "slow_tokens": slow_tokens,
+        "token_matches": token_matches,
+        "token_match_all": all(token_matches),
+        "logit_match": logit_match,
+        "logit_diff_norm": logit_diff_norm
+    }
+
+def test_lookahead_token_consistency(model_and_tokenizer, sample_inputs):
+    """
+    Test that demonstrates the potential issue with cache position indices
+    when generating lookahead tokens.
+    """
+    model, tokenizer = model_and_tokenizer
+    doc, prompt, doc_in_progress = sample_inputs
+
+    results = compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress)
+
+    # Check if the tokens are the same
+    assert results["token_match_all"], (
+        f"Fast and slow methods produced different tokens.\n"
+        f"Fast: {results['fast_tokens']}\n"
+        f"Slow: {results['slow_tokens']}"
+    )
+
+    # Check if the most likely next tokens based on logits are the same
+    assert all(results["logit_match"]), (
+        f"Fast and slow methods predicted different most likely next tokens"
+    )
+
+    # Check that the logit differences are minimal
+    # This might fail if there's a bug in the cache position indices
+    assert all(diff < 1e-4 for diff in results["logit_diff_norm"]), (
+        f"Significant difference in logits between fast and slow methods: {results['logit_diff_norm']}"
+    )
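
Note: the test above calls custom_llm_inference.get_next_token_predictions_slow, whose implementation is not part of this commit. As a reading aid only, here is a minimal sketch, in plain PyTorch/transformers, of what a cache-free "slow" reference of this kind could look like. The function name, the tokenized-context argument, and the assumption that it returns the decoded top-k candidates together with the logits of the token following each candidate are assumptions, not the project's actual code.

# Hypothetical sketch only -- not the implementation in custom_llm_inference.
# It recomputes a full forward pass (no KV cache) for each candidate lookahead token.
import torch

def slow_lookahead_sketch(model, tokenizer, context_ids, k=5):
    """For each of the k most likely next tokens, rerun the model on
    context + candidate and return the logits of the token after it."""
    with torch.no_grad():
        next_logits = model(context_ids).logits[0, -1]   # logits for the next token
    candidate_ids = next_logits.topk(k).indices          # k candidate next tokens

    lookahead_logits = []
    for tok_id in candidate_ids:
        extended = torch.cat([context_ids, tok_id.view(1, 1)], dim=1)
        with torch.no_grad():
            out = model(extended)                        # full recompute, no past_key_values
        lookahead_logits.append(out.logits[0, -1])       # logits after the candidate

    decoded = [tokenizer.decode(t) for t in candidate_ids]
    return decoded, torch.stack(lookahead_logits)

A correct cached fast path should reproduce both outputs; note that the test casts the logit difference to float32 before taking the norm and compares it against a 1e-4 threshold.
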
uv.lock
CHANGED
@@ -287,6 +287,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
 ]
 
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
+]
+
 [[package]]
 name = "ipython"
 version = "8.32.0"
@@ -719,6 +728,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0b/30/2b61876e2722374558b871dfbfcbe4e406626d63f4f6ed92e9c8e24cac37/pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25", size = 2254890 },
 ]
 
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
+]
+
 [[package]]
 name = "prompt-toolkit"
 version = "3.0.50"
@@ -917,6 +935,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/eb/f5/b9e2a42aa8f9e34d52d66de87941ecd236570c7ed2e87775ed23bbe4e224/pymdown_extensions-10.14.3-py3-none-any.whl", hash = "sha256:05e0bee73d64b9c71a4ae17c72abc2f700e8bc8403755a00580b49a4e9f189e9", size = 264467 },
 ]
 
+[[package]]
+name = "pytest"
+version = "8.3.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "colorama", marker = "sys_platform == 'win32'" },
+    { name = "iniconfig" },
+    { name = "packaging" },
+    { name = "pluggy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -1505,6 +1538,7 @@ dependencies = [
     { name = "fastapi" },
     { name = "pandas" },
     { name = "pydantic" },
+    { name = "pytest" },
     { name = "requests" },
     { name = "streamlit" },
 ]
@@ -1526,6 +1560,7 @@ requires-dist = [
     { name = "fastapi", specifier = ">=0.115.8" },
     { name = "pandas", specifier = ">=2.2.3" },
     { name = "pydantic", specifier = ">=2.10.6" },
+    { name = "pytest", specifier = ">=8.3.4" },
     { name = "requests", specifier = ">=2.32.3" },
     { name = "streamlit", specifier = "==1.40.1" },
 ]
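
With pytest now declared in pyproject.toml and locked in uv.lock, the new consistency test can be run from the synced environment. One possible programmatic invocation (a sketch; it assumes the test file sits at the repository root):

# Sketch: run only the new lookahead-consistency test via pytest's public API.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-q", "test_llm_inference.py::test_lookahead_token_consistency"]))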