|
|
|
|
|
import spaces |
|
import requests |
|
import gradio as gr |
|
from bs4 import BeautifulSoup |
|
from transformers import pipeline |
|
|
|
from kvpress import ( |
|
ExpectedAttentionPress, |
|
KnormPress, |
|
RandomPress, |
|
SnapKVPress, |
|
StreamingLLMPress, |
|
TOVAPress, |
|
) |
|
|
|
# Registry mapping the press name shown in the UI dropdown to its kvpress class.
press_dict = dict(
    ExpectedAttentionPress=ExpectedAttentionPress,
    KnormPress=KnormPress,
    RandomPress=RandomPress,
    SnapKVPress=SnapKVPress,
    StreamingLLMPress=StreamingLLMPress,
    TOVAPress=TOVAPress,
)
|
|
|
|
|
@spaces.GPU
def process_request(url, question, press_name, compression_ratio):
    """
    Answer a question about a Wikipedia article while compressing the KV cache.

    Fetches the article at `url`, extracts its paragraph text, then runs the
    module-level `pipe` (kv-press-text-generation pipeline) with the selected
    compression press applied to the key-value cache.

    Parameters
    ----------
    url : str
        URL of the Wikipedia article to fetch.
    question : str
        Question to ask about the article.
    press_name : str
        Key into the module-level `press_dict` selecting the press class.
    compression_ratio : float
        Fraction of the KV cache to prune (slider range 0.0–0.9).

    Returns
    -------
    tuple
        (answer or error message, token count before compression,
        token count after compression). Both counts are -1 on any error.
    """
    if press_name not in press_dict:
        return f"Invalid press type selected: {press_name}", -1, -1

    try:
        # FIX: bound the request so a hung server cannot block the GPU worker
        # forever; previously requests.get() had no timeout.
        response = requests.get(url, timeout=30)
        # FIX: surface HTTP 4xx/5xx as a fetch error instead of silently
        # parsing an error page as if it were the article.
        response.raise_for_status()
        content = response.content
    except requests.exceptions.RequestException as e:
        return f"Error fetching the Wikipedia article: {str(e)}", -1, -1

    try:
        # Concatenate all <p> paragraph text to form the model context.
        soup = BeautifulSoup(content, "html.parser")
        context = "".join([p.text for p in soup.find_all("p")]) + "\n\n"

        press = press_dict[press_name](compression_ratio)
        num_tokens = pipe.tokenizer(context, return_tensors="pt")["input_ids"].shape[1]
        pred_answer = pipe(context, question=question, press=press)["answer"]

        return pred_answer, num_tokens, int(num_tokens * (1 - compression_ratio))
    except Exception as e:  # UI boundary: report failures as text instead of crashing the app
        if "CUDA out of memory" in str(e):
            # BUG FIX: this branch previously returned only 2 values, which broke
            # the 3-output Gradio wiring exactly when the OOM message was needed.
            return "Error: CUDA out of memory. Try using a smaller article or a lower compression ratio.", -1, -1
        else:
            return str(e), -1, -1
|
|
|
|
|
def gradio_interface():
    """
    Build and return the Gradio Blocks UI for the kvpress demo.

    The layout: an intro markdown panel, a row with the article URL and the
    question, a row with the press selector and compression-ratio slider,
    three output widgets (answer text plus token counts before/after
    compression), a submit button, and a set of clickable examples.
    """
    with gr.Blocks() as app:
        # Intro / usage instructions shown at the top of the page.
        gr.Markdown(
            """
            # Wikipedia Article Question Answering with kvpress
            This demo uses the llama 3.1 8B Instruct model to answer questions about any given Wikipedia article.
            Under the hood, [kvpress](https://github.com/NVIDIA/kvpress) *compresses the key-value (KV) cache* associated with the article, helping reduce memory usage and accelerate decoding.

            **How to use:**
            1. Enter a Wikipedia article URL
            2. Type your question
            3. Select a press type and the desired compression ratio
            4. Press "Submit" to see the answer, along with token statistics before and after compression
            """
        )

        # Inputs: article URL and question, side by side.
        with gr.Row():
            article_url = gr.Textbox(label="Wikipedia Article URL", placeholder="Enter the Wikipedia article URL here")
            user_question = gr.Textbox(label="Question", placeholder="Type your question here")

        # Compression controls: press type and ratio, side by side.
        with gr.Row():
            press_dropdown = gr.Dropdown(
                choices=list(press_dict.keys()),
                value="ExpectedAttentionPress",
                label="Select Press Type",
            )
            ratio_slider = gr.Slider(minimum=0.0, maximum=0.9, step=0.1, value=0.5, label="Compression Ratio")

        # Outputs: the answer plus token statistics before/after compression.
        answer_box = gr.Textbox(label="Output", lines=10)
        tokens_before = gr.Number(label="Number of tokens before compression", interactive=False)
        tokens_after = gr.Number(label="Number of tokens after compression", interactive=False)

        run_btn = gr.Button("Submit")

        # Clickable example rows that pre-fill all four inputs.
        gr.Examples(
            examples=[
                [
                    "https://en.wikipedia.org/wiki/Nvidia",
                    "Complete this sentence: In May 2017, the program had 1,300 companies. As of March 2018, there were ",
                    "ExpectedAttentionPress",
                    0.5,
                ],
                [
                    "https://en.wikipedia.org/wiki/Hugging_Face",
                    "What was the original name of the transformers library ?",
                    "ExpectedAttentionPress",
                    0.5,
                ],
                [
                    "https://en.wikipedia.org/wiki/World_Chess_Championship_2024",
                    "On which move did the world chess championship end?",
                    "ExpectedAttentionPress",
                    0.5,
                ],
            ],
            inputs=[article_url, user_question, press_dropdown, ratio_slider],
        )

        # Wire the button to the inference function.
        run_btn.click(
            process_request,
            inputs=[article_url, user_question, press_dropdown, ratio_slider],
            outputs=[answer_box, tokens_before, tokens_after],
        )

    return app
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
device = "cuda:0" |
|
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct" |
|
pipe = pipeline("kv-press-text-generation", model=ckpt, device=device, torch_dtype="auto") |
|
|
|
|
|
demo = gradio_interface() |
|
demo.launch() |
|
|