vwxyzjn commited on
Commit
603bc51
·
1 Parent(s): 634833f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -12
app.py CHANGED
@@ -94,6 +94,21 @@ Q: Who won Super Bowl XX?
94
  A: Chicago Bears
95
  A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
96
  Result=Chicago Bears<submit>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  """
98
  ],
99
  }
@@ -146,6 +161,11 @@ def generate(
146
  stop_sequences=["<call>"]
147
  )
148
  generation_still_running = True
 
 
 
 
 
149
  while generation_still_running:
150
  try:
151
  stream = client.generate_stream(system_prompt + prompt, **generate_kwargs)
@@ -153,16 +173,19 @@ def generate(
153
 
154
  # call env phase
155
  output = system_prompt + prompt
156
- previous_token = ""
 
 
 
 
157
  for response in stream:
158
  if response.token.text == "<|endoftext|>":
159
  return output
160
  else:
161
  output += response.token.text
162
- previous_token = response.token.text
163
- # text env logic:
164
- tool, query = parse_tool_call(output[len(system_prompt + prompt):])
165
- print("tool", tool, query)
166
  if tool is not None and query is not None:
167
  if tool not in tools:
168
  response = f"Unknown tool {tool}."
@@ -171,22 +194,48 @@ def generate(
171
  output += response + "<response>"
172
  except Exception as error:
173
  response = f"Tool error: {str(error)}"
174
- yield output[len(system_prompt + prompt):]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  call_output = copy.deepcopy(output)
177
  # response phase
178
  generate_kwargs["stop_sequences"] = ["<submit>"]
179
  stream = client.generate_stream(output, **generate_kwargs)
180
- previous_token = ""
181
  for response in stream:
182
  if response.token.text == "<|endoftext|>":
183
  return output
184
  else:
185
  output += response.token.text
186
- previous_token = response.token.text
187
- yield output[len(system_prompt + prompt):]
188
-
189
- return output
 
 
 
 
 
 
 
 
 
 
190
  except Exception as e:
191
  if "loading" in str(e):
192
  gr.Warning("waiting for model to load... (this could take up to 20 minutes, after which things are much faster)")
@@ -264,7 +313,11 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
264
  elem_id="q-input",
265
  )
266
  submit = gr.Button("Generate", variant="primary")
267
- output = gr.Code(elem_id="q-output", lines=30, label="Output")
 
 
 
 
268
  with gr.Row():
269
  with gr.Column():
270
  with gr.Accordion("Advanced settings", open=False):
 
94
  A: Chicago Bears
95
  A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
96
  Result=Chicago Bears<submit>
97
+ """
98
+ ],
99
+ "StarCoderBase TriviaQA2": [
100
+ Client(
101
+ "https://api-inference.huggingface.co/models/vwxyzjn/starcoderbase-triviaqa",
102
+ headers={"Authorization": f"Bearer {HF_TOKEN}"},
103
+ ),
104
+ {"Wiki": tool_fn},
105
+ """\
106
+ Answer the following question:
107
+
108
+ Q: In which branch of the arts is Patricia Neary famous?
109
+ A: Ballets
110
+ A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
111
+ Result=Ballets<submit>
112
  """
113
  ],
114
  }
 
161
  stop_sequences=["<call>"]
162
  )
163
  generation_still_running = True
164
+ request_idx = -1
165
+ call_idx = -1
166
+ response_idx = -1
167
+ submit_idx = -1
168
+
169
  while generation_still_running:
170
  try:
171
  stream = client.generate_stream(system_prompt + prompt, **generate_kwargs)
 
173
 
174
  # call env phase
175
  output = system_prompt + prompt
176
+ generation_start_idx = len(output)
177
+ highlighted_output = [
178
+ (prompt, ""),
179
+ ]
180
+ yield highlighted_output
181
  for response in stream:
182
  if response.token.text == "<|endoftext|>":
183
  return output
184
  else:
185
  output += response.token.text
186
+ print(response.token.text)
187
+ tool, query = parse_tool_call(output[generation_start_idx:])
188
+ # print("tool", tool, query)
 
189
  if tool is not None and query is not None:
190
  if tool not in tools:
191
  response = f"Unknown tool {tool}."
 
194
  output += response + "<response>"
195
  except Exception as error:
196
  response = f"Tool error: {str(error)}"
197
+
198
+ if request_idx == -1:
199
+ request_idx = output[generation_start_idx:].find("<request>")
200
+ if call_idx == -1:
201
+ call_idx = output[generation_start_idx:].find("<call>")
202
+ if response_idx == -1:
203
+ response_idx = output[generation_start_idx:].find("<response>")
204
+
205
+ # if `<request>` is in the output, highlight it, if `<call>` is in the output, highlight it
206
+ # print(generation_start_idx, request_idx, call_idx, response_idx)
207
+ highlighted_output = [
208
+ (prompt, ""),
209
+ (output[generation_start_idx:generation_start_idx+request_idx], ""),
210
+ (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
211
+ (output[generation_start_idx+call_idx:-1], "call"),
212
+ ]
213
+ # print(highlighted_output, output)
214
+ yield highlighted_output
215
 
216
  call_output = copy.deepcopy(output)
217
  # response phase
218
  generate_kwargs["stop_sequences"] = ["<submit>"]
219
  stream = client.generate_stream(output, **generate_kwargs)
 
220
  for response in stream:
221
  if response.token.text == "<|endoftext|>":
222
  return output
223
  else:
224
  output += response.token.text
225
+ if submit_idx == -1:
226
+ submit_idx = output[generation_start_idx:].find("<submit>")
227
+ # print(generation_start_idx, request_idx, call_idx, response_idx, submit_idx)
228
+ highlighted_output = [
229
+ (prompt, ""),
230
+ (output[generation_start_idx:generation_start_idx+request_idx], ""),
231
+ (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
232
+ (output[generation_start_idx+call_idx:generation_start_idx+response_idx], "call"),
233
+ (output[generation_start_idx+response_idx:-1], "submit"),
234
+ ]
235
+ # print(highlighted_output, output)
236
+ yield highlighted_output
237
+
238
+ return highlighted_output
239
  except Exception as e:
240
  if "loading" in str(e):
241
  gr.Warning("waiting for model to load... (this could take up to 20 minutes, after which things are much faster)")
 
313
  elem_id="q-input",
314
  )
315
  submit = gr.Button("Generate", variant="primary")
316
+ # output = gr.Code(elem_id="q-output", lines=30, label="Output")
317
+ output = gr.HighlightedText(
318
+ label="Output",
319
+ color_map={"query": "red", "call": "green", "response": "blue", "submit": "yellow", "model": "pink"},
320
+ )
321
  with gr.Row():
322
  with gr.Column():
323
  with gr.Accordion("Advanced settings", open=False):