merve HF staff commited on
Commit
36be50d
·
verified ·
1 Parent(s): e4c787e

Bring back chatbot

Browse files
Files changed (1) hide show
  1. app.py +93 -128
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForVision2Seq
 
3
  import re
4
  import time
5
  from PIL import Image
6
  import torch
7
  import spaces
8
- import subprocess
9
  #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
 
@@ -17,36 +18,37 @@ model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-Instruct",
17
 
18
  @spaces.GPU
19
  def model_inference(
20
- images, text, assistant_prefix, decoding_strategy, temperature, max_new_tokens,
21
  repetition_penalty, top_p
22
- ):
 
 
 
 
 
 
 
 
23
  if text == "" and not images:
24
  gr.Error("Please input a query and optionally image(s).")
25
 
26
  if text == "" and images:
27
  gr.Error("Please input a text query along the image(s).")
28
 
29
- if isinstance(images, Image.Image):
30
- images = [images]
31
 
32
 
33
  resulting_messages = [
34
  {
35
  "role": "user",
36
- "content": [{"type": "image"}] + [
37
  {"type": "text", "text": text}
38
  ]
39
  }
40
  ]
41
-
42
- if assistant_prefix:
43
- text = f"{assistant_prefix} {text}"
44
-
45
-
46
  prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
47
  inputs = processor(text=prompt, images=[images], return_tensors="pt")
48
  inputs = {k: v.to("cuda") for k, v in inputs.items()}
49
-
50
  generation_args = {
51
  "max_new_tokens": max_new_tokens,
52
  "repetition_penalty": repetition_penalty,
@@ -65,119 +67,82 @@ def model_inference(
65
  generation_args["top_p"] = top_p
66
 
67
  generation_args.update(inputs)
68
-
69
  # Generate
70
- generated_ids = model.generate(**generation_args)
71
-
72
- generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
73
- return generated_texts[0]
74
-
75
-
76
- with gr.Blocks(fill_height=False) as demo:
77
- gr.Markdown("## SmolVLM: Small yet Mighty 💫")
78
- gr.Markdown("Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples.")
79
- with gr.Column():
80
- with gr.Row():
81
- image_input = gr.Image(label="Upload your Image", type="pil")
82
-
83
- with gr.Column():
84
- query_input = gr.Textbox(label="Prompt")
85
- assistant_prefix = gr.Textbox(label="Assistant Prefix", placeholder="Let's think step by step.")
86
-
87
- submit_btn = gr.Button("Submit")
88
- output = gr.Textbox(label="Output")
89
-
90
-
91
- with gr.Accordion(label="Advanced Generation Parameters", open=False):
92
- examples=[
93
- ["example_images/rococo.jpg", "What art era is this?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
94
- ["example_images/examples_wat_arun.jpg", "I'm planning a visit to this temple, give me travel tips.", "", "Greedy", 0.4, 512, 1.2, 0.8],
95
- ["example_images/examples_invoice.png", "What is the due date and the invoice date?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
96
- ["example_images/s2w_example.png", "What is this UI about?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
97
- ["example_images/examples_weather_events.png", "Where do the severe droughts happen according to this diagram?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
98
- ]
99
- # Hyper-parameters for generation
100
- max_new_tokens = gr.Slider(
101
- minimum=8,
102
- maximum=1024,
103
- value=512,
104
- step=1,
105
- interactive=True,
106
- label="Maximum number of new tokens to generate",
107
- )
108
- repetition_penalty = gr.Slider(
109
- minimum=0.01,
110
- maximum=5.0,
111
- value=1.2,
112
- step=0.01,
113
- interactive=True,
114
- label="Repetition penalty",
115
- info="1.0 is equivalent to no penalty",
116
- )
117
- temperature = gr.Slider(
118
- minimum=0.0,
119
- maximum=5.0,
120
- value=0.4,
121
- step=0.1,
122
- interactive=True,
123
- label="Sampling temperature",
124
- info="Higher values will produce more diverse outputs.",
125
- )
126
- top_p = gr.Slider(
127
- minimum=0.01,
128
- maximum=0.99,
129
- value=0.8,
130
- step=0.01,
131
- interactive=True,
132
- label="Top P",
133
- info="Higher values is equivalent to sampling more low-probability tokens.",
134
- )
135
- decoding_strategy = gr.Radio(
136
- [
137
- "Top P Sampling",
138
- "Greedy",
139
-
140
- ],
141
- value="Top P Sampling",
142
- label="Decoding strategy",
143
- interactive=True,
144
- info="Higher values is equivalent to sampling more low-probability tokens.",
145
- )
146
- decoding_strategy.change(
147
- fn=lambda selection: gr.Slider(
148
- visible=(
149
- selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
150
- )
151
- ),
152
- inputs=decoding_strategy,
153
- outputs=temperature,
154
- )
155
-
156
- decoding_strategy.change(
157
- fn=lambda selection: gr.Slider(
158
- visible=(
159
- selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
160
- )
161
- ),
162
- inputs=decoding_strategy,
163
- outputs=repetition_penalty,
164
- )
165
- decoding_strategy.change(
166
- fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
167
- inputs=decoding_strategy,
168
- outputs=top_p,
169
- )
170
- gr.Examples(
171
- examples = examples,
172
- inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
173
- max_new_tokens, repetition_penalty, top_p],
174
- outputs=output,
175
- fn=model_inference
176
- )
177
-
178
-
179
- submit_btn.click(model_inference, inputs = [image_input, query_input, assistant_prefix, decoding_strategy, temperature,
180
- max_new_tokens, repetition_penalty, top_p], outputs=output)
181
-
182
-
183
- demo.launch(debug=True)
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
3
+ from threading import Thread
4
  import re
5
  import time
6
  from PIL import Image
7
  import torch
8
  import spaces
9
+ #import subprocess
10
  #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
11
 
12
 
 
18
 
19
  @spaces.GPU
20
  def model_inference(
21
+ input_dict, history, decoding_strategy, temperature, max_new_tokens,
22
  repetition_penalty, top_p
23
+ ):
24
+ text = input_dict["text"]
25
+ print(input_dict["files"])
26
+ if len(input_dict["files"]) > 1:
27
+ images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
28
+ elif len(input_dict["files"]) == 1:
29
+ images = [Image.open(input_dict["files"][0]).convert("RGB")]
30
+
31
+
32
  if text == "" and not images:
33
  gr.Error("Please input a query and optionally image(s).")
34
 
35
  if text == "" and images:
36
  gr.Error("Please input a text query along the image(s).")
37
 
38
+
 
39
 
40
 
41
  resulting_messages = [
42
  {
43
  "role": "user",
44
+ "content": [{"type": "image"} for _ in range(len(images))] + [
45
  {"type": "text", "text": text}
46
  ]
47
  }
48
  ]
 
 
 
 
 
49
  prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
50
  inputs = processor(text=prompt, images=[images], return_tensors="pt")
51
  inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
52
  generation_args = {
53
  "max_new_tokens": max_new_tokens,
54
  "repetition_penalty": repetition_penalty,
 
67
  generation_args["top_p"] = top_p
68
 
69
  generation_args.update(inputs)
 
70
  # Generate
71
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens= True)
72
+ generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
73
+ generated_text = ""
74
+
75
+ thread = Thread(target=model.generate, kwargs=generation_args)
76
+ thread.start()
77
+ thread.join()
78
+
79
+ buffer = ""
80
+
81
+
82
+ for new_text in streamer:
83
+
84
+ buffer += new_text
85
+ generated_text_without_prompt = buffer#[len(ext_buffer):]
86
+ time.sleep(0.01)
87
+ yield buffer
88
+
89
+
90
+ examples=[
91
+ [{"text": "What art era do these artpieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
92
+ [{"text": "I'm planning a visit to this temple, give me travel tips.", "files": ["example_images/examples_wat_arun.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
93
+ [{"text": "What is the due date and the invoice date?", "files": ["example_images/examples_invoice.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
94
+ [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
95
+ [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
96
+ ]
97
+ demo = gr.ChatInterface(fn=model_inference, title="SmolVLM: Small yet Mighty 💫",
98
+ description="Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This checkpoint works best with single turn conversations, so clear the conversation after a single turn.",
99
+ examples=examples,
100
+ textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
101
+ additional_inputs=[gr.Radio(["Top P Sampling",
102
+ "Greedy"],
103
+ value="Greedy",
104
+ label="Decoding strategy",
105
+ #interactive=True,
106
+ info="Higher values is equivalent to sampling more low-probability tokens.",
107
+
108
+ ), gr.Slider(
109
+ minimum=0.0,
110
+ maximum=5.0,
111
+ value=0.4,
112
+ step=0.1,
113
+ interactive=True,
114
+ label="Sampling temperature",
115
+ info="Higher values will produce more diverse outputs.",
116
+ ),
117
+ gr.Slider(
118
+ minimum=8,
119
+ maximum=1024,
120
+ value=512,
121
+ step=1,
122
+ interactive=True,
123
+ label="Maximum number of new tokens to generate",
124
+ ), gr.Slider(
125
+ minimum=0.01,
126
+ maximum=5.0,
127
+ value=1.2,
128
+ step=0.01,
129
+ interactive=True,
130
+ label="Repetition penalty",
131
+ info="1.0 is equivalent to no penalty",
132
+ ),
133
+ gr.Slider(
134
+ minimum=0.01,
135
+ maximum=0.99,
136
+ value=0.8,
137
+ step=0.01,
138
+ interactive=True,
139
+ label="Top P",
140
+ info="Higher values is equivalent to sampling more low-probability tokens.",
141
+ )],cache_examples=False
142
+ )
143
+
144
+
145
+
146
+
147
+ demo.launch(debug=True)
148
+