vwxyzjn commited on
Commit
3550d0c
·
1 Parent(s): 56ae496

bug fix: actually use interpreter tool

Browse files
Files changed (1) hide show
  1. app.py +29 -37
app.py CHANGED
@@ -10,6 +10,7 @@ from share_btn import community_icon_html, loading_icon_html, share_js, share_bt
10
 
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
13
  print(HF_TOKEN)
14
 
15
  FIM_PREFIX = "<fim_prefix>"
@@ -90,7 +91,7 @@ Result=Chicago Bears<submit>
90
  "https://api-inference.huggingface.co/models/lvwerra/starcoderbase-gsm8k",
91
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
92
  ),
93
- {"Wiki": load_tool("lvwerra/python-interpreter")},
94
  """\
95
  Example of using a Python API to solve math questions.
96
 
@@ -166,6 +167,7 @@ def generate(
166
  response_idx = -1
167
  submit_idx = -1
168
 
 
169
  while generation_still_running:
170
  try:
171
  stream = client.generate_stream(system_prompt + prompt, **generate_kwargs)
@@ -179,18 +181,21 @@ def generate(
179
  ]
180
  yield highlighted_output
181
  for response in stream:
 
182
  if response.token.text == "<|endoftext|>":
183
  return output
184
  else:
185
  output += response.token.text
186
  tool, query = parse_tool_call(output[generation_start_idx:])
187
- # print("tool", tool, query)
188
  if tool is not None and query is not None:
 
189
  if tool not in tools:
190
  response = f"Unknown tool {tool}."
191
  try:
192
  response = tools[tool](query)
193
  output += response + "<response>"
 
194
  except Exception as error:
195
  response = f"Tool error: {str(error)}"
196
 
@@ -209,30 +214,30 @@ def generate(
209
  (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
210
  (output[generation_start_idx+call_idx:-1], "call"),
211
  ]
212
- # print(highlighted_output, output)
213
  yield highlighted_output
214
 
215
- call_output = copy.deepcopy(output)
216
- # response phase
217
- generate_kwargs["stop_sequences"] = ["<submit>"]
218
- stream = client.generate_stream(output, **generate_kwargs)
219
- for response in stream:
220
- if response.token.text == "<|endoftext|>":
221
- return output
222
- else:
223
- output += response.token.text
224
- if submit_idx == -1:
225
- submit_idx = output[generation_start_idx:].find("<submit>")
226
- # print(generation_start_idx, request_idx, call_idx, response_idx, submit_idx)
227
- highlighted_output = [
228
- (prompt, ""),
229
- (output[generation_start_idx:generation_start_idx+request_idx], ""),
230
- (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
231
- (output[generation_start_idx+call_idx:generation_start_idx+response_idx], "call"),
232
- (output[generation_start_idx+response_idx:-1], "submit"),
233
- ]
234
- # print(highlighted_output, output)
235
- yield highlighted_output
236
 
237
  return highlighted_output
238
  except Exception as e:
@@ -389,16 +394,3 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
389
  )
390
  share_button.click(None, [], [], _js=share_js)
391
  demo.queue(concurrency_count=16).launch(debug=True)
392
-
393
-
394
- """
395
- Answer the following question:
396
- Q: In which branch of the arts is Patricia Neary famous?
397
- A: Ballets
398
- A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
399
- Result=Ballets<submit>
400
- Q: Who won Super Bowl XX?
401
- A: Chicago Bears
402
- A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
403
- Result=Chicago Bears<submit>
404
- Q: In what state is Philadelphia located?"""
 
10
 
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
+ os.environ["HF_ALLOW_CODE_EVAL"] = "1"
14
  print(HF_TOKEN)
15
 
16
  FIM_PREFIX = "<fim_prefix>"
 
91
  "https://api-inference.huggingface.co/models/lvwerra/starcoderbase-gsm8k",
92
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
93
  ),
94
+ {"PythonInterpreter": load_tool("lvwerra/python-interpreter")},
95
  """\
96
  Example of using a Python API to solve math questions.
97
 
 
167
  response_idx = -1
168
  submit_idx = -1
169
 
170
+ i = 0
171
  while generation_still_running:
172
  try:
173
  stream = client.generate_stream(system_prompt + prompt, **generate_kwargs)
 
181
  ]
182
  yield highlighted_output
183
  for response in stream:
184
+ i += 1
185
  if response.token.text == "<|endoftext|>":
186
  return output
187
  else:
188
  output += response.token.text
189
  tool, query = parse_tool_call(output[generation_start_idx:])
190
+
191
  if tool is not None and query is not None:
192
+ # print("=====tool", i, tool, response, output)
193
  if tool not in tools:
194
  response = f"Unknown tool {tool}."
195
  try:
196
  response = tools[tool](query)
197
  output += response + "<response>"
198
+
199
  except Exception as error:
200
  response = f"Tool error: {str(error)}"
201
 
 
214
  (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
215
  (output[generation_start_idx+call_idx:-1], "call"),
216
  ]
217
+ print(i, highlighted_output, output)
218
  yield highlighted_output
219
 
220
+ call_output = copy.deepcopy(output)
221
+ # response phase
222
+ generate_kwargs["stop_sequences"] = ["<submit>"]
223
+ stream = client.generate_stream(output, **generate_kwargs)
224
+ for response in stream:
225
+ if response.token.text == "<|endoftext|>":
226
+ return output
227
+ else:
228
+ output += response.token.text
229
+ if submit_idx == -1:
230
+ submit_idx = output[generation_start_idx:].find("<submit>")
231
+ # print(generation_start_idx, request_idx, call_idx, response_idx, submit_idx)
232
+ highlighted_output = [
233
+ (prompt, ""),
234
+ (output[generation_start_idx:generation_start_idx+request_idx], ""),
235
+ (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
236
+ (output[generation_start_idx+call_idx:generation_start_idx+response_idx], "call"),
237
+ (output[generation_start_idx+response_idx:-1], "submit"),
238
+ ]
239
+ print(highlighted_output, output)
240
+ yield highlighted_output
241
 
242
  return highlighted_output
243
  except Exception as e:
 
394
  )
395
  share_button.click(None, [], [], _js=share_js)
396
  demo.queue(concurrency_count=16).launch(debug=True)