Niki Zhang committed (verified)
Commit f2e1e32 · 1 Parent(s): c5a524a

Update app.py

Files changed (1)
  1. app.py +74 -38
app.py CHANGED
@@ -3,7 +3,7 @@ import json
 import gradio as gr
 import numpy as np
 from gradio import processing_utils
-
+import requests
 from packaging import version
 from PIL import Image, ImageDraw
 import functools
@@ -171,28 +171,6 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     return state, state, image_input, click_state, image_input, image_input, image_embedding, \
         original_size, input_size
 
-def store_click(image_input, point_prompt, click_mode, state, click_state, evt: gr.SelectData):
-    click_index = evt.index
-    if point_prompt == 'Positive':
-        coordinate = [click_index[0], click_index[1], 1]
-    else:
-        coordinate = [click_index[0], click_index[1], 0]
-
-    if click_mode == 'Continuous':
-        click_state[0].append(coordinate)
-    elif click_mode == 'Single':
-        click_state[0] = [coordinate]  # Overwrite with latest click
-
-    return state, click_state
-
-def generate_caption(image_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt):
-    last_click = click_state[0][-1]
-    point_prompt = 'Positive' if last_click[2] == 1 else 'Negative'
-
-    return inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt, gr.SelectData(index=(last_click[0], last_click[1])))
-
-
-
 
 
 def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
@@ -230,14 +208,13 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
 
     state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
-    state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
     update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
     text = out['generated_captions']['raw_caption']
     input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
     origin_image_input = image_input
-    image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
-                                      input_points=input_points, input_labels=input_labels)
+    # image_input = create_bubble_frame(image_input, None, (click_index[0], click_index[1]), input_mask,
+    #                                   input_points=input_points, input_labels=input_labels)
     x, y = input_points[-1]
 
     if visual_chatgpt is not None:
@@ -247,19 +224,62 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
         point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
         visual_chatgpt.point_prompt = point_prompt
 
-    yield state, state, click_state, image_input
+
+    generated_caption = text
+    print(generated_caption)
+
+    yield state, state, click_state, image_input, generated_caption
+
     if not args.disable_gpt and model.text_refiner:
         refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
                                                        enable_wiki=enable_wiki)
-        # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
         new_cap = refined_caption['caption']
         if refined_caption['wiki']:
             state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
         state = state + [(None, f"caption: {new_cap}")]
-        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
-                                                  input_mask,
-                                                  input_points=input_points, input_labels=input_labels)
-        yield state, state, click_state, refined_image_input
+        # refined_image_input = create_bubble_frame(origin_image_input, None, (click_index[0], click_index[1]),
+        #                                           input_mask,
+        #                                           input_points=input_points, input_labels=input_labels)
+        yield state, state, click_state, image_input, new_cap
+
+def submit_caption(image_input, state,generated_caption):
+    print(state)
+    if state and isinstance(state[-1][1], dict):
+        params = state[-1][1]
+    else:
+        params = {}
+
+    click_index = params.get("click_index", (0, 0))
+    input_mask = params.get("input_mask", np.zeros((1, 1)))
+    input_points = params.get("input_points", [])
+    input_labels = params.get("input_labels", [])
+
+    click_index = params.get("click_index", (0, 0))
+    input_mask = params.get("input_mask", np.zeros((1, 1)))
+    input_points = params.get("input_points", [])
+    input_labels = params.get("input_labels", [])
+
+    image_input = create_bubble_frame(np.array(image_input), generated_caption, (click_index[0], click_index[1]), input_mask,
+                                      input_points=input_points, input_labels=input_labels)
+
+
+    if generated_caption:
+        state = state + [(None, f"RAW_Caption: {generated_caption}")]
+        txt2speech(generated_caption)
+
+    yield state,state,image_input
+
+
+def txt2speech(text):
+    print("Initializing text-to-speech conversion...")
+    # API_URL = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_vits"
+    # headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
+    # payloads = {'inputs': text}
+    # response = requests.post(API_URL, headers=headers, json=payloads)
+    # with open('audio_story.mp3', 'wb') as file:
+    #     file.write(response.content)
+    print("Text-to-speech conversion completed.")
+
 
 
 def get_sketch_prompt(mask: Image.Image):
@@ -322,7 +342,7 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
     origin_image_input = image_input
 
     fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
-    image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
+    image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
 
     yield state, state, image_input
 
@@ -415,6 +435,7 @@ def create_ui():
         visual_chatgpt = gr.State(None)
         original_size = gr.State(None)
         input_size = gr.State(None)
+        generated_caption = gr.State("")
         # img_caption = gr.State(None)
         aux_state = gr.State([])
 
@@ -442,7 +463,7 @@ def create_ui():
                 with gr.Row(scale=0.4):
                     clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
                     clear_button_image = gr.Button(value="Clear Image", interactive=True)
-                    submit_button_click = gr.Button(value="Submit", interactive=True)
+                    submit_button_click=gr.Button(value="Submit", interactive=True)
             with gr.Tab("Trajectory (beta)"):
                 sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
                                                elem_id="image_sketcher")
@@ -566,8 +587,6 @@ def create_ui():
         )
         clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
 
-        submit_button_click.click(generate_caption, inputs=[origin_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt], outputs=[chatbot, state, click_state, image_input])
-
         image_input.clear(
             lambda: (None, [], [], [[], [], []], "", "", ""),
             [],
@@ -595,8 +614,25 @@ def create_ui():
                                            [chatbot, state, origin_image, click_state, image_input, sketcher_input,
                                             image_embedding, original_size, input_size])
         example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
-        # select coordinate
-        image_input.select(store_click, inputs=[origin_image, point_prompt, click_mode, state, click_state], outputs=[state, click_state])
+
+        image_input.select(
+            inference_click,
+            inputs=[
+                origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
+                image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
+            ],
+            outputs=[chatbot, state, click_state, image_input, generated_caption],
+            show_progress=False, queue=True
+        )
+
+        submit_button_click.click(
+            submit_caption,
+            inputs=[image_input, state, generated_caption],
+            outputs=[chatbot,state,image_input],
+            show_progress=True, queue=True
+        )
+
+
 
         submit_button_sketcher.click(
            inference_traject,
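
Note on the new txt2speech helper: it is called from submit_caption, but the commit keeps its Hugging Face Inference API request commented out, so for now it only prints the two log lines. The snippet below is a minimal sketch (not part of the commit) of what the enabled body could look like, assuming a valid HUGGINGFACEHUB_API_TOKEN in the environment; the endpoint URL, payload, and output filename are taken from the commented-out lines, and the raise_for_status() check is an added safeguard.

    import os
    import requests

    def txt2speech(text):
        # Sketch of the commented-out body in app.py; not enabled in this commit.
        API_URL = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_vits"
        headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
        payloads = {'inputs': text}
        response = requests.post(API_URL, headers=headers, json=payloads)
        response.raise_for_status()  # avoid writing an error response body to the audio file
        with open('audio_story.mp3', 'wb') as file:
            file.write(response.content)  # raw audio bytes returned by the endpoint

If enabled this way, submit_caption would both append the raw caption to the chat state and save an audio rendering of it to audio_story.mp3 after drawing the bubble frame.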