merve HF staff committed on
Commit
95dbe7e
1 Parent(s): 35dad4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -50
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  import string
3
-
4
  import gradio as gr
5
  import PIL.Image
6
  import torch
7
  from transformers import BitsAndBytesConfig, pipeline
8
  import re
 
9
 
10
  DESCRIPTION = "# LLaVA 🌋"
11
 
@@ -21,48 +22,79 @@ pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_con
21
  def extract_response_pairs(text):
22
  turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
23
  turns = [turn.strip() for turn in turns if turn.strip()]
24
- print(turns[1::2])
25
  conv_list = []
26
  for i in range(0, len(turns[1::2]), 2):
27
  if i + 1 < len(turns[1::2]):
28
- conv_list.append((turns[1::2][i].lstrip(":"), turns[1::2][i + 1].lstrip(":")))
 
29
  return conv_list
30
 
31
 
32
- def postprocess_output(output: str) -> str:
33
- if output and output[-1] not in string.punctuation:
34
- output += "."
35
- return output
36
 
 
 
 
37
 
 
 
 
 
 
 
 
38
 
39
- def chat(image, text, temperature, length_penalty,
40
- repetition_penalty, max_length, min_length, top_p,
41
- history_chat):
42
-
43
- prompt = " ".join(history_chat) + f"USER: <image>\n{text}\nASSISTANT:"
44
-
45
- outputs = pipe(image, prompt=prompt,
46
  generate_kwargs={"temperature":temperature,
47
  "length_penalty":length_penalty,
48
  "repetition_penalty":repetition_penalty,
49
  "max_length":max_length,
50
  "min_length":min_length,
51
  "top_p":top_p})
52
-
 
53
 
54
- history_chat.append(outputs[0]["generated_text"])
55
 
56
 
57
- chat_val = extract_response_pairs(" ".join(history_chat))
58
- return chat_val, history_chat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
  css = """
62
  #mkd {
63
- height: 500px;
64
- overflow: auto;
65
- border: 1px solid #ccc;
66
  }
67
  """
68
  with gr.Blocks(css="style.css") as demo:
@@ -74,16 +106,12 @@ with gr.Blocks(css="style.css") as demo:
74
  chatbot = gr.Chatbot(label="Chat", show_label=False)
75
  gr.Markdown("Input image and text and start chatting 👇")
76
  with gr.Row():
77
-
78
  image = gr.Image(type="pil")
79
  text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
80
-
81
-
82
-
83
  history_chat = gr.State(value=[])
84
- with gr.Row():
85
- clear_chat_button = gr.Button("Clear")
86
- chat_button = gr.Button("Submit", variant="primary")
87
  with gr.Accordion(label="Advanced settings", open=False):
88
  temperature = gr.Slider(
89
  label="Temperature",
@@ -135,18 +163,7 @@ with gr.Blocks(css="style.css") as demo:
135
  chatbot,
136
  history_chat
137
  ]
138
- chat_button.click(fn=chat, inputs=[image,
139
- text_input,
140
- temperature,
141
- length_penalty,
142
- repetition_penalty,
143
- max_length,
144
- min_length,
145
- top_p,
146
- history_chat],
147
- outputs=chat_output,
148
- api_name="Chat",
149
- )
150
 
151
  chat_inputs = [
152
  image,
@@ -159,15 +176,31 @@ with gr.Blocks(css="style.css") as demo:
159
  top_p,
160
  history_chat
161
  ]
 
 
 
 
 
 
 
 
 
 
 
 
162
  text_input.submit(
163
- fn=chat,
164
- inputs=chat_inputs,
165
- outputs=chat_output
166
- ).success(
167
- fn=lambda: "",
168
- outputs=chat_inputs,
169
- queue=False,
170
- api_name=False,
 
 
 
 
171
  )
172
  clear_chat_button.click(
173
  fn=lambda: ([], []),
@@ -187,7 +220,6 @@ with gr.Blocks(css="style.css") as demo:
187
  history_chat
188
  ],
189
  queue=False)
190
-
191
  examples = [["./examples/baklava.png", "How to make this pastry?"],["./examples/bee.png","Describe this image."]]
192
  gr.Examples(examples=examples, inputs=[image, text_input, chat_inputs])
193
 
@@ -195,4 +227,4 @@ with gr.Blocks(css="style.css") as demo:
195
 
196
 
197
  if __name__ == "__main__":
198
- demo.queue(max_size=10).launch()
 
1
  import os
2
  import string
3
+ import copy
4
  import gradio as gr
5
  import PIL.Image
6
  import torch
7
  from transformers import BitsAndBytesConfig, pipeline
8
  import re
9
+ import time
10
 
11
  DESCRIPTION = "# LLaVA 🌋"
12
 
 
22
def extract_response_pairs(text):
    """Parse a raw ``USER:``/``ASSISTANT:`` transcript into chatbot pairs.

    Splits *text* on the role markers, then pairs consecutive message
    bodies as ``[user_message, assistant_message]`` lists — the format
    ``gr.Chatbot`` expects. A trailing user turn with no assistant reply
    is dropped.

    Parameters:
        text: full generated transcript containing the role markers.

    Returns:
        list of two-element lists, one per completed exchange.
    """
    turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
    turns = [turn.strip() for turn in turns if turn.strip()]
    # After the split, markers sit at even indices and message bodies at
    # odd indices, so the bodies are turns[1::2]. Hoisted once here — the
    # original recomputed this slice three times per loop iteration.
    # NOTE(review): assumes strict marker/body alternation; a whitespace-only
    # body would be filtered out above and shift the pairing — confirm the
    # model never emits an empty turn.
    contents = turns[1::2]
    conv_list = []
    for i in range(0, len(contents), 2):
        if i + 1 < len(contents):
            conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
    return conv_list
31
 
32
 
 
 
 
 
33
 
34
def add_text(history, text):
    """Append the user's message to the chat history with a pending reply slot.

    Bug fix: the original wrote ``history = history.append([text, None])``,
    which rebinds *history* to ``None`` (``list.append`` returns ``None``),
    so the chatbot component was handed ``None`` and cleared. Mutate in
    place and return the list itself instead.

    Parameters:
        history: chatbot history, a list of ``[user, assistant]`` pairs.
        text: the new user message.

    Returns:
        ``(history, text)`` — the updated history and the unchanged text,
        so the follow-up ``bot`` callback can still read the prompt.
    """
    history.append([text, None])
    return history, text
37
 
38
def infer(image, prompt,
          temperature,
          length_penalty,
          repetition_penalty,
          max_length,
          min_length,
          top_p):
    """Run one image-to-text generation through the module-level `pipe`.

    All sampling knobs are forwarded verbatim via ``generate_kwargs``;
    the generated string of the first result is returned.
    """
    generate_kwargs = {
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "max_length": max_length,
        "min_length": min_length,
        "top_p": top_p,
    }
    results = pipe(images=image, prompt=prompt, generate_kwargs=generate_kwargs)
    return results[0]["generated_text"]
55
 
 
56
 
57
 
58
def bot(history_chat, text_input, image,
        temperature,
        length_penalty,
        repetition_penalty,
        max_length,
        min_length,
        top_p):
    """Chat callback: prompt the model with the running history, stream the reply.

    Yields the parsed conversation repeatedly, revealing the last assistant
    response one character at a time so ``gr.Chatbot`` renders a typing
    effect (0.05 s per character).

    Fixes vs. original: removed a leftover debug ``print``; guards against
    an empty parse, where ``chat_val[-1]`` raised ``IndexError``.

    NOTE(review): the event wiring appears to pass the ``chatbot`` component
    (a list of message pairs) as *history_chat*, which ``" ".join`` cannot
    consume — confirm callers pass a list of transcript strings.
    """
    # Flatten prior transcripts and append the new user turn in the
    # USER:/ASSISTANT: template the LLaVA pipeline expects.
    chat_history = " ".join(history_chat)
    chat_history = chat_history + f"USER: <image>\n{text_input}\nASSISTANT:"

    inference_result = infer(image, chat_history,
                             temperature,
                             length_penalty,
                             repetition_penalty,
                             max_length,
                             min_length,
                             top_p)
    # Parse the raw transcript back into [user, assistant] pairs.
    chat_val = extract_response_pairs(inference_result)
    if not chat_val:
        # Nothing parseable came back; show an empty chat instead of
        # crashing on chat_val[-1].
        yield []
        return

    # Deep-copy so the streamed list does not alias the parsed result,
    # then blank the final reply and re-type it character by character.
    chat_state_list = copy.deepcopy(chat_val)
    chat_state_list[-1][1] = ""
    for character in chat_val[-1][1]:
        chat_state_list[-1][1] += character
        time.sleep(0.05)
        yield chat_state_list
91
 
92
 
93
  css = """
94
  #mkd {
95
+ height: 500px;
96
+ overflow: auto;
97
+ border: 1px solid #ccc;
98
  }
99
  """
100
  with gr.Blocks(css="style.css") as demo:
 
106
  chatbot = gr.Chatbot(label="Chat", show_label=False)
107
  gr.Markdown("Input image and text and start chatting 👇")
108
  with gr.Row():
109
+
110
  image = gr.Image(type="pil")
111
  text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
112
+
 
 
113
  history_chat = gr.State(value=[])
114
+
 
 
115
  with gr.Accordion(label="Advanced settings", open=False):
116
  temperature = gr.Slider(
117
  label="Temperature",
 
163
  chatbot,
164
  history_chat
165
  ]
166
+
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  chat_inputs = [
169
  image,
 
176
  top_p,
177
  history_chat
178
  ]
179
+ with gr.Row():
180
+ clear_chat_button = gr.Button("Clear")
181
+ chat_button = gr.Button("Submit", variant="primary")
182
+
183
+ chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(bot, [chatbot, text_input,
184
+ image, temperature,
185
+ length_penalty,
186
+ repetition_penalty,
187
+ max_length,
188
+ min_length,
189
+ top_p], chatbot)
190
+
191
  text_input.submit(
192
+ add_text,
193
+ [chatbot, text_input],
194
+ [chatbot, text_input]
195
+ ).then(
196
+ fn=bot,
197
+ inputs=[chatbot, text_input, image, temperature,
198
+ length_penalty,
199
+ repetition_penalty,
200
+ max_length,
201
+ min_length,
202
+ top_p],
203
+ outputs=chatbot
204
  )
205
  clear_chat_button.click(
206
  fn=lambda: ([], []),
 
220
  history_chat
221
  ],
222
  queue=False)
 
223
  examples = [["./examples/baklava.png", "How to make this pastry?"],["./examples/bee.png","Describe this image."]]
224
  gr.Examples(examples=examples, inputs=[image, text_input, chat_inputs])
225
 
 
227
 
228
 
229
  if __name__ == "__main__":
230
+ demo.queue(max_size=10).launch(debug=True)