Create app.py
app.py
ADDED
@@ -0,0 +1,93 @@
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Load model and tokenizer once at startup
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
print(f"Model loaded on: {model.device}")

# Define the generation function
def generate_text(prompt, max_new_tokens, num_beams):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = inputs["input_ids"].shape[-1]

    # Greedy search
    start_time = time.time()
    outputs_greedy = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
        do_sample=False,
    )
    greedy_time = time.time() - start_time
    # Remove the prompt tokens from the output
    generated_tokens_greedy = outputs_greedy[0][input_length:]
    generated_text_greedy = tokenizer.decode(generated_tokens_greedy, skip_special_tokens=True)

    # Beam search
    start_time = time.time()
    outputs_beam = model.generate(
        **inputs,
        num_beams=int(num_beams),
        num_return_sequences=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
    )
    beam_time = time.time() - start_time
    # Remove the prompt tokens as above
    generated_tokens_beam = outputs_beam[0][input_length:]
    generated_text_beam = tokenizer.decode(generated_tokens_beam, skip_special_tokens=True)

    # Prepare outputs for better display formatting
    greedy_details = (
        f"**Strategy:** Picks the most probable token at each step (deterministic).\n\n"
        f"**Time:** {greedy_time:.2f} seconds"
    )

    beam_details = (
        f"**Strategy:** Explores {num_beams} beams concurrently and returns the top candidate.\n\n"
        f"**Time:** {beam_time:.2f} seconds"
    )

    return greedy_details, generated_text_greedy, beam_details, generated_text_beam

with gr.Blocks() as demo:
    # Informational header to help users understand the demo
    gr.Markdown(
        "# Beam Search Demo\n\n"
        "This demo shows how two different text generation strategies work using the Qwen2.5-0.5B model. "
        "The left side uses **greedy search**, which picks the most probable token at every generation step (deterministic), "
        "while the right side uses **beam search**, which explores multiple beams concurrently to choose the most likely "
        "sequence of tokens.\n\n"
        "**Important:** This model works best with prompts that need completion rather than question-answering. For example, "
        "instead of 'What is the capital of France?', use prompts like 'The capital of France is' or 'Here is a story about:'\n\n"
        "Use the controls below to enter your prompt, adjust the maximum number of newly generated tokens, and set the "
        "number of beams for beam search. The results for both strategies are displayed side by side for easy comparison.\n\n"
    )

    # Input components in a single column at the top
    with gr.Column():
        gr.Markdown("## Input")
        prompt_input = gr.Textbox(label="Prompt", value="Here is a funny love letter for you:")
        max_tokens_input = gr.Slider(minimum=1, maximum=100, step=1, label="Max new tokens", value=50)
        num_beams_input = gr.Slider(minimum=1, maximum=20, step=1, label="Number of beams", value=10)
        generate_btn = gr.Button("Generate")

    with gr.Row():
        with gr.Column():
            greedy_details_output = gr.Markdown(label="Greedy Search Details")
            greedy_textbox_output = gr.Textbox(label="Greedy Search Generated Text", lines=10)
        with gr.Column():
            beam_details_output = gr.Markdown(label="Beam Search Details")
            beam_textbox_output = gr.Textbox(label="Beam Search Generated Text", lines=10)

    # Connect the button click event to the generation function
    generate_btn.click(
        generate_text,
        inputs=[prompt_input, max_tokens_input, num_beams_input],
        outputs=[greedy_details_output, greedy_textbox_output, beam_details_output, beam_textbox_output]
    )

if __name__ == "__main__":
    demo.launch()
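Once the Space is running, the same comparison can be driven programmatically instead of through the browser UI. The snippet below is a minimal sketch using gradio_client; the local URL and the "/generate_text" api_name are assumptions (Gradio usually derives the endpoint name from the function name, and the Space's "Use via API" page shows the actual one).

from gradio_client import Client

# Assumed URL and endpoint name; verify against the Space's "Use via API" page.
client = Client("http://127.0.0.1:7860/")
greedy_details, greedy_text, beam_details, beam_text = client.predict(
    "The capital of France is",  # prompt
    50,                          # max new tokens
    5,                           # number of beams
    api_name="/generate_text",   # assumed endpoint name
)
print(beam_text)

Note that this commit adds only app.py; on Hugging Face Spaces the dependencies (gradio, transformers, torch, plus gradio_client for the snippet above) would typically be declared in a separate requirements.txt.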