File size: 4,966 Bytes
d75eb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bf4e38
d75eb4f
 
 
 
 
4bf4e38
d75eb4f
 
 
 
 
 
 
 
 
 
 
4bf4e38
d75eb4f
 
 
 
4bf4e38
d75eb4f
 
 
 
 
 
 
 
 
3ff5cda
d75eb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bf4e38
d75eb4f
 
4bf4e38
 
d75eb4f
 
 
 
 
 
 
d84c61c
d75eb4f
 
 
4bf4e38
 
 
 
 
ff1672e
 
 
 
 
 
4bf4e38
d75eb4f
 
 
 
 
 
 
4bf4e38
d75eb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# This space is mostly a copy of the work of Aritra Roy Gosthipaty (see https://huggingface.co/spaces/ariG23498/kv-press/blob/main/app.py)

import spaces
import requests
import gradio as gr
from bs4 import BeautifulSoup
from transformers import pipeline

from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

# Registry of available KV-cache compression methods, keyed by the class name
# shown in the UI dropdown. `cls.__name__` matches the original string keys.
press_dict = {
    cls.__name__: cls
    for cls in (
        ExpectedAttentionPress,
        KnormPress,
        RandomPress,
        SnapKVPress,
        StreamingLLMPress,
        TOVAPress,
    )
}


@spaces.GPU
def process_request(url, question, press_name, compression_ratio):
    """Answer a question about a Wikipedia article using a compressed KV cache.

    Fetches the article at `url`, extracts its paragraph text, compresses the
    model's KV cache with the selected press, and runs question answering.

    Parameters
    ----------
    url : str
        Wikipedia article URL to fetch.
    question : str
        Question to answer about the article.
    press_name : str
        Key into `press_dict` selecting the compression method.
    compression_ratio : float
        Fraction of the KV cache to drop (0.0 - 0.9 from the UI slider).

    Returns
    -------
    tuple[str, int, int]
        (answer, tokens_before_compression, tokens_after_compression).
        On any error, (error_message, -1, -1).
    """
    if press_name not in press_dict:
        return f"Invalid press type selected: {press_name}", -1, -1

    # Fetch the Wikipedia article; timeout so a hung request cannot block the
    # GPU worker indefinitely.
    try:
        content = requests.get(url, timeout=30).content
    except requests.exceptions.RequestException as e:
        return f"Error fetching the Wikipedia article: {str(e)}", -1, -1

    try:
        # Parse the Wikipedia HTML and keep only the paragraph text.
        soup = BeautifulSoup(content, "html.parser")
        context = "".join([p.text for p in soup.find_all("p")]) + "\n\n"

        # Initialize the press and run compressed-cache question answering.
        press = press_dict[press_name](compression_ratio)
        num_tokens = pipe.tokenizer(context, return_tensors="pt")["input_ids"].shape[1]
        pred_answer = pipe(context, question=question, press=press)["answer"]

        return pred_answer, num_tokens, int(num_tokens * (1 - compression_ratio))
    except Exception as e:
        if "CUDA out of memory" in str(e):
            # Bug fix: this branch previously returned only 2 values, breaking
            # the 3-output Gradio binding instead of showing the OOM message.
            return "Error: CUDA out of memory. Try using a smaller article or a lower compression ratio.", -1, -1
        else:
            return str(e), -1, -1

def gradio_interface():
    """Build and return the Gradio Blocks UI for the kvpress QA demo."""
    with gr.Blocks() as demo:
        # Intro / usage instructions.
        gr.Markdown(
            """
            # Wikipedia Article Question Answering with kvpress
            This demo uses the llama 3.1 8B Instruct model to answer questions about any given Wikipedia article.  
            Under the hood, [kvpress](https://github.com/NVIDIA/kvpress) *compresses the key-value (KV) cache* associated with the article, helping reduce memory usage and accelerate decoding.

            **How to use:**
            1. Enter a Wikipedia article URL 
            2. Type your question
            3. Select a press type and the desired compression ratio
            4. Press "Submit" to see the answer, along with token statistics before and after compression
            """
        )

        # Article URL and question inputs, side by side.
        with gr.Row():
            url_box = gr.Textbox(label="Wikipedia Article URL", placeholder="Enter the Wikipedia article URL here")
            question_box = gr.Textbox(label="Question", placeholder="Type your question here")

        # Compression controls.
        with gr.Row():
            press_dropdown = gr.Dropdown(
                choices=list(press_dict.keys()),
                value="ExpectedAttentionPress",
                label="Select Press Type",
            )
            ratio_slider = gr.Slider(minimum=0.0, maximum=0.9, step=0.1, value=0.5, label="Compression Ratio")

        # Outputs: the answer plus token counts before/after compression.
        answer_box = gr.Textbox(label="Output", lines=10)
        tokens_before = gr.Number(label="Number of tokens before compression", interactive=False)
        tokens_after = gr.Number(label="Number of tokens after compression", interactive=False)

        submit_btn = gr.Button("Submit")

        # The same four components feed both the examples and the handler.
        request_inputs = [url_box, question_box, press_dropdown, ratio_slider]

        gr.Examples(
            examples=[
                [
                    "https://en.wikipedia.org/wiki/Nvidia",
                    "Complete this sentence: In May 2017, the program had 1,300 companies. As of March 2018, there were ",
                    "ExpectedAttentionPress",
                    0.5,
                ],
                [
                    "https://en.wikipedia.org/wiki/Hugging_Face",
                    "What was the original name of the transformers library ?",
                    "ExpectedAttentionPress",
                    0.5,
                ],
                [
                    "https://en.wikipedia.org/wiki/World_Chess_Championship_2024",
                    "On which move did the world chess championship end?",
                    "ExpectedAttentionPress",
                    0.5,
                ],
            ],
            inputs=request_inputs,
        )

        submit_btn.click(
            process_request,
            inputs=request_inputs,
            outputs=[answer_box, tokens_before, tokens_after],
        )

    return demo


if __name__ == "__main__":

    # Load pipeline. NOTE: `pipe` is a module-level global read by
    # process_request(), so it must be created before the demo handles requests.
    device = "cuda:0"
    ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    pipe = pipeline("kv-press-text-generation", model=ckpt, device=device, torch_dtype="auto")

    # Launch demo
    demo = gradio_interface()
    demo.launch()