shlomihod committed on
Commit
874ee5e
·
1 Parent(s): cb99910

switch together to openai api

Browse files
Files changed (1) hide show
  1. app.py +21 -50
app.py CHANGED
@@ -204,8 +204,11 @@ def reload_module(name):
204
  def build_api_call_function(model):
205
  global HOW_OPENAI_INITIATED
206
 
207
- if model.startswith("openai") or model.startswith("azure"):
208
- provider, model = model.split("/")
 
 
 
209
 
210
  if provider == "openai":
211
  from openai import AsyncOpenAI
@@ -222,11 +225,20 @@ def build_api_call_function(model):
222
  azure_endpoint=AZURE_OPENAI_ENDPOINT,
223
  )
224
 
225
- async def list_models():
226
- return [model async for model in aclient.models.list()]
 
 
 
 
227
 
228
- openai_models = {model_obj.id for model_obj in asyncio.run(list_models())}
229
- assert model in openai_models
 
 
 
 
 
230
 
231
  @retry(
232
  wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
@@ -242,7 +254,9 @@ def build_api_call_function(model):
242
  top_p = generation_config["top_p"] if generation_config["do_sample"] else 1
243
  max_tokens = generation_config["max_new_tokens"]
244
 
245
- if model.startswith("gpt") and "instruct" not in model:
 
 
246
  response = await aclient.chat.completions.create(
247
  model=model,
248
  messages=[{"role": "user", "content": prompt}],
@@ -270,49 +284,6 @@ def build_api_call_function(model):
270
 
271
  return output, length
272
 
273
- elif model.startswith("together"):
274
- TOGETHER_API_ENDPOINT = "https://api.together.xyz/inference"
275
-
276
- provider, model = model.split("/", maxsplit=1)
277
-
278
- @retry(
279
- wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
280
- stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
281
- reraise=True,
282
- )
283
- async def api_call_function(prompt, generation_config):
284
- headers = {
285
- "Authorization": f"Bearer {TOGETHER_API_KEY}",
286
- "User-Agent": "FM",
287
- }
288
-
289
- payload = {
290
- "temperature": generation_config["temperature"]
291
- if generation_config["do_sample"]
292
- else 0,
293
- "top_p": generation_config["top_p"]
294
- if generation_config["do_sample"]
295
- else 1,
296
- "top_k": generation_config["top_k"]
297
- if generation_config["do_sample"]
298
- else 0,
299
- "max_tokens": generation_config["max_new_tokens"],
300
- "prompt": prompt,
301
- "model": model,
302
- "stop": generation_config["stop_sequences"],
303
- }
304
-
305
- LOGGER.info(f"{payload=}")
306
-
307
- async with aiohttp.ClientSession() as session:
308
- async with session.post(
309
- TOGETHER_API_ENDPOINT, json=payload, headers=headers
310
- ) as response:
311
- output = (await response.json())["output"]["choices"][0]["text"]
312
- length = None
313
-
314
- return output, length
315
-
316
  elif model.startswith("cohere"):
317
  _, model = model.split("/")
318
 
 
204
  def build_api_call_function(model):
205
  global HOW_OPENAI_INITIATED
206
 
207
+ if any(
208
+ model.startswith(known_providers)
209
+ for known_providers in ("openai", "azure", "together")
210
+ ):
211
+ provider, model = model.split("/", maxsplit=1)
212
 
213
  if provider == "openai":
214
  from openai import AsyncOpenAI
 
225
  azure_endpoint=AZURE_OPENAI_ENDPOINT,
226
  )
227
 
228
+ elif provider == "together":
229
+ from openai import AsyncOpenAI
230
+
231
+ aclient = AsyncOpenAI(
232
+ api_key=TOGETHER_API_KEY, base_url="https://api.together.xyz/v1"
233
+ )
234
 
235
+ if provider in ("openai", "azure"):
236
+
237
+ async def list_models():
238
+ return [model async for model in aclient.models.list()]
239
+
240
+ openai_models = {model_obj.id for model_obj in asyncio.run(list_models())}
241
+ assert model in openai_models
242
 
243
  @retry(
244
  wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
 
254
  top_p = generation_config["top_p"] if generation_config["do_sample"] else 1
255
  max_tokens = generation_config["max_new_tokens"]
256
 
257
+ if (
258
+ model.startswith("gpt") and "instruct" not in model
259
+ ) or provider == "together":
260
  response = await aclient.chat.completions.create(
261
  model=model,
262
  messages=[{"role": "user", "content": prompt}],
 
284
 
285
  return output, length
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  elif model.startswith("cohere"):
288
  _, model = model.split("/")
289