merve (HF staff) committed
Commit 3d139ce
1 Parent(s): 47173ac

Update app.py

Files changed (1)
  app.py  +15 -87
app.py CHANGED
@@ -1,7 +1,3 @@
-#!/usr/bin/env python
-
-from __future__ import annotations
-
 import os
 import string
 
@@ -21,9 +17,12 @@ quantization_config = BitsAndBytesConfig(
 pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
 
 
+DESCRIPTION = "LLaVA is now available in transformers!"
+
 def extract_response_pairs(text):
     pattern = re.compile(r'(USER:.*?)ASSISTANT:(.*?)(?:$|USER:)', re.DOTALL)
     matches = pattern.findall(text)
+    print(matches)
 
     pairs = [(user.strip(), assistant.strip()) for user, assistant in matches]
 
@@ -37,26 +36,19 @@ def postprocess_output(output: str) -> str:
 
 
 
-def chat(image, text, temperature, length_penalty,
-         repetition_penalty, max_length, min_length, num_beams, top_p,
-         history_chat):
+def chat(image, text, max_length, history_chat):
 
-    prompt = " ".join(history_chat)
-    prompt = f"USER: <image>\n{text}\nASSISTANT:"
+    prompt = " ".join(history_chat) + f"USER: <image>\n{text}\nASSISTANT:"
 
     outputs = pipe(image, prompt=prompt,
-                   generate_kwargs={"temperature":temperature,
-                                    "length_penalty":length_penalty,
-                                    "repetition_penalty":repetition_penalty,
-                                    "max_length":max_length,
-                                    "min_length":min_length,
-                                    "num_beams":num_beams,
-                                    "top_p":top_p})
+                   generate_kwargs={
+                       "max_length":max_length})
 
-    output = postprocess_output(outputs[0]["generated_text"])
-    history_chat.append(output)
+    #output = postprocess_output(outputs[0]["generated_text"])
+    history_chat.append(outputs[0]["generated_text"])
 
     chat_val = extract_response_pairs(" ".join(history_chat))
+
     return chat_val, history_chat
 
 
@@ -69,89 +61,31 @@ css = """
 """
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
-    gr.Markdown("**LLaVA, one of the greatest multimodal chat models is now available in transformers with 4-bit quantization! ⚡️ **")
-    gr.Markdown("**Try it in this demo 🤗 **")
-
     chatbot = gr.Chatbot(label="Chat", show_label=False)
-    gr.Markdown("Input image and text and start chatting 👇")
     with gr.Row():
-
         image = gr.Image(type="pil")
-        text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
-
-
+        text_input = gr.Text(label="Chat Input", show_label=False, max_lines=1, container=False)
 
     history_chat = gr.State(value=[])
     with gr.Row():
         clear_chat_button = gr.Button("Clear")
         chat_button = gr.Button("Submit", variant="primary")
     with gr.Accordion(label="Advanced settings", open=False):
-        temperature = gr.Slider(
-            label="Temperature",
-            info="Used with nucleus sampling.",
-            minimum=0.5,
-            maximum=1.0,
-            step=0.1,
-            value=1.0,
-        )
-        length_penalty = gr.Slider(
-            label="Length Penalty",
-            info="Set to larger for longer sequence, used with beam search.",
-            minimum=-1.0,
-            maximum=2.0,
-            step=0.2,
-            value=1.0,
-        )
-        repetition_penalty = gr.Slider(
-            label="Repetition Penalty",
-            info="Larger value prevents repetition.",
-            minimum=1.0,
-            maximum=5.0,
-            step=0.5,
-            value=1.5,
-        )
         max_length = gr.Slider(
             label="Max Length",
             minimum=1,
-            maximum=512,
+            maximum=200,
             step=1,
-            value=50,
-        )
-        min_length = gr.Slider(
-            label="Minimum Length",
-            minimum=1,
-            maximum=100,
-            step=1,
-            value=1,
-        )
-        num_beams = gr.Slider(
-            label="Number of Beams",
-            minimum=1,
-            maximum=10,
-            step=1,
-            value=5,
-        )
-        top_p = gr.Slider(
-            label="Top P",
-            info="Used with nucleus sampling.",
-            minimum=0.5,
-            maximum=1.0,
-            step=0.1,
-            value=0.9,
+            value=100,
         )
+
     chat_output = [
         chatbot,
         history_chat
     ]
     chat_button.click(fn=chat, inputs=[image,
                                        text_input,
-                                       temperature,
-                                       length_penalty,
-                                       repetition_penalty,
                                        max_length,
-                                       min_length,
-                                       num_beams,
-                                       top_p,
                                        history_chat],
                       outputs=chat_output,
                       api_name="Chat",
@@ -160,13 +94,7 @@ with gr.Blocks(css="style.css") as demo:
     chat_inputs = [
         image,
         text_input,
-        temperature,
-        length_penalty,
-        repetition_penalty,
         max_length,
-        min_length,
-        num_beams,
-        top_p,
         history_chat
     ]
     text_input.submit(
@@ -201,4 +129,4 @@ with gr.Blocks(css="style.css") as demo:
 
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch()
+    demo.queue(max_size=10).launch(debug=True)
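
For reference, the code path this commit keeps is small enough to exercise outside Gradio. The sketch below is a minimal standalone version of it; the checkpoint name llava-hf/llava-1.5-7b-hf, the 4-bit BitsAndBytesConfig settings, and the example.jpg path are placeholders, since the diff only shows model=model_id and the opening line of the BitsAndBytesConfig call.

import torch
from PIL import Image
from transformers import BitsAndBytesConfig, pipeline

# Assumed 4-bit settings; app.py constructs a BitsAndBytesConfig whose arguments are not shown in this diff.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Assumed checkpoint; app.py passes model=model_id, whose value is defined outside this diff.
pipe = pipeline(
    "image-to-text",
    model="llava-hf/llava-1.5-7b-hf",
    model_kwargs={"quantization_config": quantization_config},
)

image = Image.open("example.jpg")  # hypothetical local test image
prompt = "USER: <image>\nWhat is in the image?\nASSISTANT:"

# Same call shape as the new chat(): only max_length is forwarded to generation.
outputs = pipe(image, prompt=prompt, generate_kwargs={"max_length": 100})
print(outputs[0]["generated_text"])  # transcript text, including the prompt prefix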
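
The history handling still relies on extract_response_pairs to turn the accumulated transcript back into (user, assistant) tuples for gr.Chatbot. A minimal sketch of that regex on a hypothetical single-round transcript:

import re

def extract_response_pairs(text):
    # Same pattern as app.py: capture a USER turn up to ASSISTANT:, then the reply
    # up to either the end of the string or the next USER: marker.
    pattern = re.compile(r'(USER:.*?)ASSISTANT:(.*?)(?:$|USER:)', re.DOTALL)
    matches = pattern.findall(text)
    return [(user.strip(), assistant.strip()) for user, assistant in matches]

# Hypothetical single-round transcript, in the format chat() builds.
transcript = "USER: <image>\nWhat is in the image?\nASSISTANT: A red bicycle leaning against a wall."
print(extract_response_pairs(transcript))
# [('USER: <image>\nWhat is in the image?', 'A red bicycle leaning against a wall.')]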