spawn99 commited on
Commit
6cdbea4
·
verified ·
1 Parent(s): 27d333e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import gradio as gr
4
+
5
+ # Load model and tokenizer once at startup
6
+ model_name = "Qwen/Qwen2.5-0.5B"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
9
+ print(f"Model loaded on: {model.device}")
10
+
11
+ # Define the generation function
12
+ def generate_text(prompt, max_new_tokens, num_beams):
13
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
14
+ input_length = inputs["input_ids"].shape[-1]
15
+ # Greedy search
16
+ start_time = time.time()
17
+ outputs_greedy = model.generate(
18
+ **inputs,
19
+ max_new_tokens=int(max_new_tokens),
20
+ num_beams=1,
21
+ do_sample=False,
22
+ )
23
+ greedy_time = time.time() - start_time
24
+ # Remove the prompt tokens from the output
25
+ generated_tokens_greedy = outputs_greedy[0][input_length:]
26
+ generated_text_greedy = tokenizer.decode(generated_tokens_greedy, skip_special_tokens=True)
27
+
28
+ # Beam search
29
+ start_time = time.time()
30
+ outputs_beam = model.generate(
31
+ **inputs,
32
+ num_beams=int(num_beams),
33
+ num_return_sequences=1,
34
+ max_new_tokens=int(max_new_tokens),
35
+ do_sample=False,
36
+ )
37
+ beam_time = time.time() - start_time
38
+ # Remove the prompt tokens as above
39
+ generated_tokens_beam = outputs_beam[0][input_length:]
40
+ generated_text_beam = tokenizer.decode(generated_tokens_beam, skip_special_tokens=True)
41
+
42
+ # Prepare outputs for better display formatting
43
+ greedy_details = (
44
+ f"**Strategy:** Picks the most probable token at each step (deterministic).\n\n"
45
+ f"**Time:** {greedy_time:.2f} seconds"
46
+ )
47
+
48
+ beam_details = (
49
+ f"**Strategy:** Explores {num_beams} beams concurrently and returns the top candidate.\n\n"
50
+ f"**Time:** {beam_time:.2f} seconds"
51
+ )
52
+
53
+ return greedy_details, generated_text_greedy, beam_details, generated_text_beam
54
+
55
+ with gr.Blocks() as demo:
56
+ # Informational header to help users understand the demo
57
+ gr.Markdown(
58
+ "# Beam Search Demo\n\n"
59
+ "This demo shows how two different text generation strategies work using the Qwen2.5-0.5B model. "
60
+ "The left side uses **greedy search**, which picks the most probable token at every generation step (deterministic), "
61
+ "while the right side uses **beam search**, which explores multiple beams concurrently to choose the most likely "
62
+ "sequence of tokens.\n\n"
63
+ "**Important:** This model works best with prompts that need completion rather than question-answering. For example, "
64
+ "instead of 'What is the capital of France?', use prompts like 'The capital of France is' or 'Here is a story about:'\n\n"
65
+ "Use the controls below to enter your prompt, adjust the maximum number of newly generated tokens, and set the "
66
+ "number of beams for beam search. The results for both strategies are displayed side by side for easy comparison.\n\n"
67
+ )
68
+
69
+ # Input components in a single column at the top
70
+ with gr.Column():
71
+ gr.Markdown("## Input")
72
+ prompt_input = gr.Textbox(label="Prompt", value="Here is a funny love letter for you:")
73
+ max_tokens_input = gr.Slider(minimum=1, maximum=100, step=1, label="Max new tokens", value=50)
74
+ num_beams_input = gr.Slider(minimum=1, maximum=20, step=1, label="Number of beams", value=10)
75
+ generate_btn = gr.Button("Generate")
76
+
77
+ with gr.Row():
78
+ with gr.Column():
79
+ greedy_details_output = gr.Markdown(label="Greedy Search Details")
80
+ greedy_textbox_output = gr.Textbox(label="Greedy Search Generated Text", lines=10)
81
+ with gr.Column():
82
+ beam_details_output = gr.Markdown(label="Beam Search Details")
83
+ beam_textbox_output = gr.Textbox(label="Beam Search Generated Text", lines=10)
84
+
85
+ # Connect the button click event to the generation function
86
+ generate_btn.click(
87
+ generate_text,
88
+ inputs=[prompt_input, max_tokens_input, num_beams_input],
89
+ outputs=[greedy_details_output, greedy_textbox_output, beam_details_output, beam_textbox_output]
90
+ )
91
+
92
+ if __name__ == "__main__":
93
+ demo.launch()