ki1207 committed
Commit 9e66ec3 · verified · 1 Parent(s): 654ae11

Update app.py

Files changed (1):
  app.py  +38 −251
app.py CHANGED
@@ -7,30 +7,24 @@ import string
 
 import gradio as gr
 import PIL.Image
-import spaces
 import torch
 from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
 
-DESCRIPTION = "# [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
+DESCRIPTION = "# [BLIP-2 VQA Ad Listing Analysis](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
 
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU.</p>"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-
 MODEL_ID = "Salesforce/instructblip-flan-t5-xl"
 
-
-
 processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
 model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
 
-
-
-
-def generate_caption(
+def answer_ad_listing_question(
     image: PIL.Image.Image,
+    title: str,
     decoding_method: str = "Nucleus sampling",
     temperature: float = 1.0,
     length_penalty: float = 1.0,
@@ -40,35 +34,15 @@ def generate_caption(
     num_beams: int = 5,
     top_p: float = 0.9,
 ) -> str:
-    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
-    generated_ids = model.generate(
-        pixel_values=inputs.pixel_values,
-        do_sample=decoding_method == "Nucleus sampling",
-        temperature=temperature,
-        length_penalty=length_penalty,
-        repetition_penalty=repetition_penalty,
-        max_length=max_length,
-        min_length=min_length,
-        num_beams=num_beams,
-        top_p=top_p,
-    )
-    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-    return result
-
+    # The prompt template with the provided title
+    prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
+Identify the species mentioned in the text, including specific names, e.g., 'Nile crocodile' instead of just 'crocodile'.
+Select the product type from the following options: Animal fibers, Animal parts (bone or bone-like), Animal parts (fleshy), Coral product, Egg, Extract, Food, Ivory products, Live, Medicine, Nests, Organs and tissues, Powder, Scales or spines, Shells, Skin or leather products, Taxidermy, Insects.
+The response should be in the format:
+"Product Type: [type]
+Species: [species]"
+"""
 
-
-def answer_question(
-    image: PIL.Image.Image,
-    prompt: str,
-    decoding_method: str = "Nucleus sampling",
-    temperature: float = 1.0,
-    length_penalty: float = 1.0,
-    repetition_penalty: float = 1.5,
-    max_length: int = 50,
-    min_length: int = 1,
-    num_beams: int = 5,
-    top_p: float = 0.9,
-) -> str:
     inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
     generated_ids = model.generate(
         **inputs,
@@ -84,239 +58,52 @@ def answer_question(
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     return result
 
-
 def postprocess_output(output: str) -> str:
     if output and output[-1] not in string.punctuation:
         output += "."
     return output
 
-
-def chat(
-    image: PIL.Image.Image,
-    text: str,
-    decoding_method: str = "Nucleus sampling",
-    temperature: float = 1.0,
-    length_penalty: float = 1.0,
-    repetition_penalty: float = 1.5,
-    max_length: int = 50,
-    min_length: int = 1,
-    num_beams: int = 5,
-    top_p: float = 0.9,
-    history_orig: list[str] = [],
-    history_qa: list[str] = [],
-) -> tuple[list[tuple[str, str]], list[str], list[str]]:
-    history_orig.append(text)
-    text_qa = f"Question: {text} Answer:"
-    history_qa.append(text_qa)
-    prompt = " ".join(history_qa)
-
-    output = answer_question(
-        image=image,
-        prompt=prompt,
-        decoding_method=decoding_method,
-        temperature=temperature,
-        length_penalty=length_penalty,
-        repetition_penalty=repetition_penalty,
-        max_length=max_length,
-        min_length=min_length,
-        num_beams=num_beams,
-        top_p=top_p,
-    )
-    output = postprocess_output(output)
-    history_orig.append(output)
-    history_qa.append(output)
-
-    chat_val = list(zip(history_orig[0::2], history_orig[1::2]))
-    return chat_val, history_orig, history_qa
-
-
-examples = [
-    [
-        "images/house.png",
-        "How could someone get out of the house?",
-    ],
-    [
-        "images/flower.jpg",
-        "What is this flower and where is it's origin?",
-    ],
-    [
-        "images/pizza.jpg",
-        "What are steps to cook it?",
-    ],
-    [
-        "images/sunset.jpg",
-        "Here is a romantic message going along the photo:",
-    ],
-    [
-        "images/forbidden_city.webp",
-        "In what dynasties was this place built?",
-    ],
-]
-
 with gr.Blocks() as demo:
     gr.Markdown(DESCRIPTION)
 
     with gr.Group():
+        # Image and ad title input
         image = gr.Image(type="pil")
-        with gr.Tabs():
-            with gr.Tab(label="Image Captioning"):
-                caption_button = gr.Button("Caption it!")
-                caption_output = gr.Textbox(label="Caption Output", show_label=False, container=False)
-            with gr.Tab(label="Visual Question Answering"):
-                chatbot = gr.Chatbot(label="VQA Chat", show_label=False)
-                history_orig = gr.State(value=[])
-                history_qa = gr.State(value=[])
-                vqa_input = gr.Text(label="Chat Input", show_label=False, max_lines=1, container=False)
-                with gr.Row():
-                    clear_chat_button = gr.Button("Clear")
-                    chat_button = gr.Button("Submit", variant="primary")
-    with gr.Accordion(label="Advanced settings", open=False):
-        text_decoding_method = gr.Radio(
-            label="Text Decoding Method",
-            choices=["Beam search", "Nucleus sampling"],
-            value="Nucleus sampling",
-        )
-        temperature = gr.Slider(
-            label="Temperature",
-            info="Used with nucleus sampling.",
-            minimum=0.5,
-            maximum=1.0,
-            step=0.1,
-            value=1.0,
-        )
-        length_penalty = gr.Slider(
-            label="Length Penalty",
-            info="Set to larger for longer sequence, used with beam search.",
-            minimum=-1.0,
-            maximum=2.0,
-            step=0.2,
-            value=1.0,
-        )
-        repetition_penalty = gr.Slider(
-            label="Repetition Penalty",
-            info="Larger value prevents repetition.",
-            minimum=1.0,
-            maximum=5.0,
-            step=0.5,
-            value=1.5,
-        )
-        max_length = gr.Slider(
-            label="Max Length",
-            minimum=20,
-            maximum=512,
-            step=1,
-            value=50,
-        )
-        min_length = gr.Slider(
-            label="Minimum Length",
-            minimum=1,
-            maximum=100,
-            step=1,
-            value=1,
-        )
-        num_beams = gr.Slider(
-            label="Number of Beams",
-            minimum=1,
-            maximum=10,
-            step=1,
-            value=5,
-        )
-        top_p = gr.Slider(
-            label="Top P",
-            info="Used with nucleus sampling.",
-            minimum=0.5,
-            maximum=1.0,
-            step=0.1,
-            value=0.9,
-        )
+        ad_title = gr.Textbox(label="Ad Title", placeholder="Enter the ad title here", lines=1)
 
-    gr.Examples(
-        examples=examples,
-        inputs=[image, vqa_input],
-        outputs=caption_output,
-        fn=generate_caption,
-    )
+        # Output section
+        answer_output = gr.Textbox(label="Ad Listing Analysis", show_label=True, placeholder="Response will appear here.")
 
-    caption_button.click(
-        fn=generate_caption,
+    # Submit and clear buttons
+    with gr.Row():
+        submit_button = gr.Button("Analyze Ad Listing", variant="primary")
+        clear_button = gr.Button("Clear")
+
+    # Logic to handle clicking on "Analyze Ad Listing"
+    submit_button.click(
+        fn=answer_ad_listing_question,
         inputs=[
             image,
-            text_decoding_method,
-            temperature,
-            length_penalty,
-            repetition_penalty,
-            max_length,
-            min_length,
-            num_beams,
-            top_p,
+            ad_title, # The title from the ad
+            "Nucleus sampling", # Default values for decoding method, temperature, etc.
+            1.0, # temperature
+            1.0, # length_penalty
+            1.5, # repetition_penalty
+            50, # max_length
+            1, # min_length
+            5, # num_beams
+            0.9, # top_p
         ],
-        outputs=caption_output,
-        api_name="caption",
+        outputs=answer_output,
     )
 
-    chat_inputs = [
-        image,
-        vqa_input,
-        text_decoding_method,
-        temperature,
-        length_penalty,
-        repetition_penalty,
-        max_length,
-        min_length,
-        num_beams,
-        top_p,
-        history_orig,
-        history_qa,
-    ]
-    chat_outputs = [
-        chatbot,
-        history_orig,
-        history_qa,
-    ]
-    vqa_input.submit(
-        fn=chat,
-        inputs=chat_inputs,
-        outputs=chat_outputs,
-    ).success(
-        fn=lambda: "",
-        outputs=vqa_input,
-        queue=False,
-        api_name=False,
-    )
-    chat_button.click(
-        fn=chat,
-        inputs=chat_inputs,
-        outputs=chat_outputs,
-        api_name="chat",
-    ).success(
-        fn=lambda: "",
-        outputs=vqa_input,
-        queue=False,
-        api_name=False,
-    )
-    clear_chat_button.click(
-        fn=lambda: ("", [], [], []),
-        inputs=None,
-        outputs=[
-            vqa_input,
-            chatbot,
-            history_orig,
-            history_qa,
-        ],
-        queue=False,
-        api_name="clear",
-    )
-    image.change(
-        fn=lambda: ("", [], [], []),
+    # Logic to handle clearing the inputs and outputs
+    clear_button.click(
+        fn=lambda: ("", "", ""), # Clear all the fields
         inputs=None,
-        outputs=[
-            caption_output,
-            chatbot,
-            history_orig,
-            history_qa,
-        ],
+        outputs=[image, ad_title, answer_output],
         queue=False,
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch()
+    demo.queue(max_size=10).launch()
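
The prompt added in this commit asks the model to answer in a fixed two-field format ('Product Type: [type]' on one line, 'Species: [species]' on the next), and postprocess_output may append a trailing period. Below is a minimal, illustrative sketch of how a caller could split such a reply back into its fields; the parse_listing_analysis helper and the sample reply are assumptions for demonstration, not part of this commit.

import re


def parse_listing_analysis(reply: str) -> dict[str, str]:
    # Illustrative helper (not part of the commit): extract the two fields the
    # prompt asks the model to return.
    fields = {"Product Type": "", "Species": ""}
    for key in fields:
        # Capture text after "<key>:" up to the next known key or the end of the reply.
        match = re.search(
            rf"{key}\s*:\s*(.*?)(?=(?:Product Type|Species)\s*:|$)",
            reply,
            flags=re.IGNORECASE | re.DOTALL,
        )
        if match:
            # Trim whitespace, stray quotes from the template, and the period
            # that postprocess_output may append.
            fields[key] = match.group(1).strip(' \n"\'.')
    return fields


if __name__ == "__main__":
    # Hypothetical model reply following the requested format.
    print(parse_listing_analysis('Product Type: Skin or leather products\nSpecies: Nile crocodile."'))
    # -> {'Product Type': 'Skin or leather products', 'Species': 'Nile crocodile'}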