# This space is mostly a copy of the work of Aritra Roy Gosthipaty (see https://huggingface.co/spaces/ariG23498/kv-press/blob/main/app.py)
import spaces
import requests
import gradio as gr
from bs4 import BeautifulSoup
from transformers import pipeline
from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

press_dict = {
    "ExpectedAttentionPress": ExpectedAttentionPress,
    "KnormPress": KnormPress,
    "RandomPress": RandomPress,
    "SnapKVPress": SnapKVPress,
    "StreamingLLMPress": StreamingLLMPress,
    "TOVAPress": TOVAPress,
}
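
# Each press above implements a different KV-cache eviction strategy. As an illustrative
# sketch only (not executed here), a press is instantiated with a target compression ratio
# and passed to the kv-press-text-generation pipeline (`pipe`, loaded at the bottom of this file):
#
#     press = KnormPress(compression_ratio=0.5)
#     answer = pipe(long_context, question="...", press=press)["answer"]
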
@spaces.GPU
def process_request(url, question, press_name, compression_ratio):
""" """
if press_name not in press_dict:
return f"Invalid press type selected: {press_name}", -1, -1
# Fetch the Wikipedia article
try:
content = requests.get(url).content
except requests.exceptions.RequestException as e:
return f"Error fetching the Wikipedia article: {str(e)}", -1, -1
try:
# Parse the Wikipedia HTML
soup = BeautifulSoup(content, "html.parser")
context = "".join([p.text for p in soup.find_all("p")]) + "\n\n"
# Initialize the press
press = press_dict[press_name](compression_ratio)
num_tokens = pipe.tokenizer(context, return_tensors="pt")["input_ids"].shape[1]
pred_answer = pipe(context, question=question, press=press)["answer"]
return pred_answer, num_tokens, int(num_tokens * (1 - compression_ratio))
except Exception as e:
if "CUDA out of memory" in str(e):
return "Error: CUDA out of memory. Try using a smaller article or a lower compression ratio.", -1
else:
return str(e), -1, -1
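
# Illustrative direct call of the handler above (assumes the pipeline loaded at the bottom of
# this file and a CUDA GPU; the question string is just a placeholder):
#
#     answer, n_before, n_after = process_request(
#         "https://en.wikipedia.org/wiki/Nvidia", "Who founded Nvidia?", "ExpectedAttentionPress", 0.5
#     )
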
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Wikipedia Article Question Answering with kvpress

            This demo uses the Llama 3.1 8B Instruct model to answer questions about any given Wikipedia article.
            Under the hood, [kvpress](https://github.com/NVIDIA/kvpress) *compresses the key-value (KV) cache* associated with the article, helping reduce memory usage and accelerate decoding.

            **How to use:**

            1. Enter a Wikipedia article URL
            2. Type your question
            3. Select a press type and the desired compression ratio
            4. Press "Submit" to see the answer, along with token statistics before and after compression
            """
        )
        with gr.Row():
            url_input = gr.Textbox(label="Wikipedia Article URL", placeholder="Enter the Wikipedia article URL here")
            question_input = gr.Textbox(label="Question", placeholder="Type your question here")
        with gr.Row():
            press_selector = gr.Dropdown(
                choices=list(press_dict.keys()),
                value="ExpectedAttentionPress",
                label="Select Press Type",
            )
            compression_slider = gr.Slider(minimum=0.0, maximum=0.9, step=0.1, value=0.5, label="Compression Ratio")
        output = gr.Textbox(label="Output", lines=10)
        output_num_tokens = gr.Number(label="Number of tokens before compression", interactive=False)
        output_compressed_num_tokens = gr.Number(label="Number of tokens after compression", interactive=False)
        submit_button = gr.Button("Submit")
        gr.Examples(
            examples=[
                [
                    "https://en.wikipedia.org/wiki/Nvidia",
                    "Complete this sentence: In May 2017, the program had 1,300 companies. As of March 2018, there were ",
                    "ExpectedAttentionPress",
                    0.5,
                ],
                [
                    "https://en.wikipedia.org/wiki/Hugging_Face",
                    "What was the original name of the transformers library?",
                    "ExpectedAttentionPress",
                    0.5,
                ],
                [
                    "https://en.wikipedia.org/wiki/World_Chess_Championship_2024",
                    "On which move did the world chess championship end?",
                    "ExpectedAttentionPress",
                    0.5,
                ],
            ],
            inputs=[url_input, question_input, press_selector, compression_slider],
        )
        submit_button.click(
            process_request,
            inputs=[url_input, question_input, press_selector, compression_slider],
            outputs=[output, output_num_tokens, output_compressed_num_tokens],
        )
    return demo
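
# Note: running this file directly requires kvpress installed, a CUDA GPU ("cuda:0" below),
# and access to the gated meta-llama/Meta-Llama-3.1-8B-Instruct checkpoint on the Hugging Face Hub.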
if __name__ == "__main__":
    # Load pipeline
    device = "cuda:0"
    ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    pipe = pipeline("kv-press-text-generation", model=ckpt, device=device, torch_dtype="auto")

    # Launch demo
    demo = gradio_interface()
    demo.launch()