shlomihod committed
Commit ccf18cf · 1 Parent(s): 19fa3f3

update versions of api

Files changed (1)
  1. app.py +60 -36
app.py CHANGED
@@ -160,7 +160,17 @@ def prepare_huggingface_generation_config(generation_config):
     else:
         assert generation_config["do_sample"]
 
-    return generation_config
+    # TODO: refactor this part
+    if generation_config["is_chat"]:
+        generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
+
+        generation_config["stop"] = generation_config.pop("stop_sequences")
+        del generation_config["do_sample"]
+        del generation_config["top_k"]
+
+    is_chat = generation_config.pop("is_chat")
+
+    return generation_config, is_chat
 
 
 def escape_markdown(text):
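Note: the hunk above reshapes the Hugging Face-style generation config into the parameter names chat endpoints expect. A minimal standalone sketch of the same transformation, using a hypothetical config dict (the values are illustrative, not taken from app.py):

    config = {
        "do_sample": True,
        "max_new_tokens": 128,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.7,
        "stop_sequences": ["\nUser:"],
        "is_chat": True,
    }

    # Same renames as the diff: chat endpoints take max_tokens/stop and
    # do not accept do_sample/top_k.
    if config["is_chat"]:
        config["max_tokens"] = config.pop("max_new_tokens")
        config["stop"] = config.pop("stop_sequences")
        del config["do_sample"]
        del config["top_k"]

    is_chat = config.pop("is_chat")
    # config is now {"top_p": 0.95, "temperature": 0.7, "max_tokens": 128, "stop": ["\nUser:"]}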
@@ -195,29 +205,28 @@ def build_api_call_function(model):
     global HOW_OPENAI_INITIATED
 
     if model.startswith("openai") or model.startswith("azure"):
-        import openai
-
         provider, model = model.split("/")
 
         if provider == "openai":
-            # TODO: how to avoid hardcoding this?
-            # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/__init__.py#L49
-            openai.api_type = "open_ai"
-            openai.api_base = "https://api.openai.com/v1"
-            openai.api_version = None
-            openai.api_key = OPENAI_API_KEY
-            engine = None
+            from openai import AsyncOpenAI
+
+            aclient = AsyncOpenAI(api_key=OPENAI_API_KEY)
 
         elif provider == "azure":
-            openai.api_type = "azure"
-            openai.api_base = AZURE_OPENAI_ENDPOINT
-            openai.api_version = "2023-05-15"
-            openai.api_key = AZURE_OPENAI_KEY
-            engine = AZURE_DEPLOYMENT_NAME
+            from openai import AsyncAzureOpenAI
 
-        openai_models = {model_obj["id"] for model_obj in openai.Model.list()["data"]}
+            aclient = AsyncAzureOpenAI(
+                # https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning
+                api_version="2023-07-01-preview",
+                # https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
+                azure_endpoint=AZURE_OPENAI_ENDPOINT,
+            )
+
+        async def list_models():
+            return [model async for model in aclient.models.list()]
+
+        openai_models = {model_obj.id for model_obj in asyncio.run(list_models())}
         assert model in openai_models
-        LOGGER.info(f"API URL {openai.api_base}")
 
         @retry(
             wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
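Note: openai-python v1 replaces the module-level api_type/api_base/api_version/api_key globals with client objects, which is what the hunk above switches to. The Azure client here passes no api_key; as far as I know the v1 SDK falls back to environment variables (AZURE_OPENAI_API_KEY for AsyncAzureOpenAI, OPENAI_API_KEY for AsyncOpenAI) when the argument is omitted. A standalone sketch of the model-listing pattern used above, assuming OPENAI_API_KEY is set:

    import asyncio

    from openai import AsyncOpenAI

    async def list_model_ids():
        aclient = AsyncOpenAI()  # api_key read from OPENAI_API_KEY if omitted
        # models.list() returns an async paginator, hence the async comprehension
        return [m.id async for m in aclient.models.list()]

    print(asyncio.run(list_model_ids()))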
@@ -234,20 +243,18 @@ def build_api_call_function(model):
             max_tokens = generation_config["max_new_tokens"]
 
             if model.startswith("gpt") and "instruct" not in model:
-                response = await openai.ChatCompletion.acreate(
-                    engine=engine,
+                response = await aclient.chat.completions.create(
                     model=model,
                     messages=[{"role": "user", "content": prompt}],
                     temperature=temperature,
                     top_p=top_p,
                     max_tokens=max_tokens,
                 )
-                assert response["choices"][0]["message"]["role"] == "assistant"
-                output = response["choices"][0]["message"]["content"]
+                assert response.choices[0].message.role == "assistant"
+                output = response.choices[0].message.content
 
             else:
-                response = await openai.Completion.acreate(
-                    engine=engine,
+                response = await aclient.completions.create(
                     model=model,
                     prompt=prompt,
                     temperature=temperature,
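Note: the v0 calls openai.ChatCompletion.acreate / openai.Completion.acreate (and the Azure-only engine= argument) become methods on the async client, and responses are typed objects rather than dicts. A minimal sketch of the new chat call shape, assuming the client from the previous hunk and an illustrative model name:

    import asyncio

    from openai import AsyncOpenAI

    async def demo():
        aclient = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        response = await aclient.chat.completions.create(
            model="gpt-3.5-turbo",  # illustrative; app.py passes the configured model
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.7,
            max_tokens=64,
        )
        # v1 responses use attribute access instead of dict indexing
        print(response.choices[0].message.content)

    asyncio.run(demo())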
@@ -344,8 +351,11 @@ def build_api_call_function(model):
         )
 
         async def api_call_function(prompt, generation_config):
-            generation_config = prepare_huggingface_generation_config(generation_config)
+            generation_config, _ = prepare_huggingface_generation_config(
+                generation_config
+            )
 
+            # TODO: include chat
             output = pipe(prompt, return_text=True, **generation_config)[0][
                 "generated_text"
             ]
@@ -365,22 +375,32 @@ def build_api_call_function(model):
         async def api_call_function(prompt, generation_config):
             hf_client = AsyncInferenceClient(token=HF_TOKEN, model=model)
 
-            generation_config = prepare_huggingface_generation_config(generation_config)
-
-            response = await hf_client.text_generation(
-                prompt, stream=False, details=True, **generation_config
+            generation_config, is_chat = prepare_huggingface_generation_config(
+                generation_config
             )
-            LOGGER.info(response)
 
-            length = len(response.details.prefill) + len(response.details.tokens)
+            if is_chat:
+                messages = [{"role": "user", "content": prompt}]
+                response = await hf_client.chat_completion(
+                    messages, stream=False, **generation_config
+                )
+                output = response.choices[0].message.content
+                length = None
+
+            else:
+                response = await hf_client.text_generation(
+                    prompt, stream=False, details=True, **generation_config
+                )
 
-            output = response.generated_text
+                length = (
+                    len(response.details.prefill) + len(response.details.tokens)
+                    if response.details is not None
+                    else None
+                )
 
-            # response = st.session_state.client.post(json={"inputs": prompt})
-            # output = response.json()[0]["generated_text"]
-            # output = st.session_state.client.conversational(prompt, model=model)
-            # output = output if "https" in st.session_state.client.model else output[len(prompt) :]
+                output = response.generated_text
 
+            # TODO: refactor to support stop of chats
             # Remove stop sequences from the output
             # Inspired by
             # https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
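Note: the hunk above branches on is_chat: huggingface_hub's AsyncInferenceClient exposes both text_generation (with prefill/token details) and an OpenAI-style chat_completion. A standalone sketch of the chat path, with an illustrative model id and the token taken from the environment:

    import asyncio
    import os

    from huggingface_hub import AsyncInferenceClient

    async def demo():
        client = AsyncInferenceClient(
            token=os.environ.get("HF_TOKEN"),
            model="HuggingFaceH4/zephyr-7b-beta",  # illustrative chat-capable model
        )
        response = await client.chat_completion(
            [{"role": "user", "content": "Say hi"}], max_tokens=32, stream=False
        )
        return response.choices[0].message.content

    print(asyncio.run(demo()))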
@@ -836,7 +856,11 @@ def main():
             st.caption("Dataset")
             st.write(data_card)
             try:
-                model_card = model_info(model).cardData
+                model_info_respose = model_info(model)
+                model_card = model_info_respose.cardData
+                st.session_state["generation_config"]["is_chat"] = (
+                    "conversational" in model_info_respose.tags
+                )
             except (HFValidationError, RepositoryNotFoundError):
                 pass
             else:
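Note: the is_chat flag consumed by prepare_huggingface_generation_config is set here from hub metadata: model_info(...).tags contains "conversational" for chat models. A quick standalone check, with an illustrative model id:

    from huggingface_hub import model_info

    info = model_info("HuggingFaceH4/zephyr-7b-beta")  # illustrative model id
    print("conversational" in info.tags)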
 