shlomihod committed
Commit · 874ee5e
1 Parent(s): cb99910

switch together to openai api

app.py CHANGED
@@ -204,8 +204,11 @@ def reload_module(name):
 def build_api_call_function(model):
     global HOW_OPENAI_INITIATED
 
-    if …
-        …
+    if any(
+        model.startswith(known_providers)
+        for known_providers in ("openai", "azure", "together")
+    ):
+        provider, model = model.split("/", maxsplit=1)
 
     if provider == "openai":
         from openai import AsyncOpenAI
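Why `maxsplit=1`: Together model ids themselves contain a slash (`org/model`), so splitting only once separates the provider prefix while keeping the rest of the id intact. A quick illustration (the id below is just an example):

```python
# Splitting once separates the provider prefix without breaking the model id.
name = "together/mistralai/Mixtral-8x7B-Instruct-v0.1"  # example id
provider, model = name.split("/", maxsplit=1)
assert (provider, model) == ("together", "mistralai/Mixtral-8x7B-Instruct-v0.1")
```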
@@ -222,11 +225,20 @@ def build_api_call_function(model):
             azure_endpoint=AZURE_OPENAI_ENDPOINT,
         )
 
-        …
-        …
+        elif provider == "together":
+            from openai import AsyncOpenAI
+
+            aclient = AsyncOpenAI(
+                api_key=TOGETHER_API_KEY, base_url="https://api.together.xyz/v1"
+            )
 
-        …
-        …
+        if provider in ("openai", "azure"):
+
+            async def list_models():
+                return [model async for model in aclient.models.list()]
+
+            openai_models = {model_obj.id for model_obj in asyncio.run(list_models())}
+            assert model in openai_models
 
         @retry(
             wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
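This hunk is the heart of the change: Together exposes an OpenAI-compatible API, so the same `AsyncOpenAI` client class works against it with only a different key and `base_url`, and the bespoke HTTP code removed below becomes unnecessary. A minimal sketch, assuming `openai>=1.0` and a `TOGETHER_API_KEY` environment variable:

```python
import asyncio
import os

from openai import AsyncOpenAI

# Same client class as for OpenAI proper; only the key and base_url differ.
aclient = AsyncOpenAI(
    api_key=os.environ["TOGETHER_API_KEY"],
    base_url="https://api.together.xyz/v1",
)

async def main():
    # models.list() paginates asynchronously, as in the list_models() helper above.
    ids = [m.id async for m in aclient.models.list()]
    print(f"{len(ids)} models available")

asyncio.run(main())
```

Note that the diff runs the `assert model in openai_models` check only for `openai`/`azure`; Together model ids are not validated up front.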
@@ -242,7 +254,9 @@ def build_api_call_function(model):
             top_p = generation_config["top_p"] if generation_config["do_sample"] else 1
             max_tokens = generation_config["max_new_tokens"]
 
-            if …
+            if (
+                model.startswith("gpt") and "instruct" not in model
+            ) or provider == "together":
                 response = await aclient.chat.completions.create(
                     model=model,
                     messages=[{"role": "user", "content": prompt}],
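The widened condition routes requests between the two endpoints: chat-style `gpt-*` models and everything served through Together use `chat.completions`, while instruct-style models presumably fall through to the legacy completions call. A hypothetical restatement of the predicate (model names are illustrative):

```python
def uses_chat_endpoint(provider: str, model: str) -> bool:
    # Chat endpoint for gpt-* chat models and for all Together-served models.
    return (model.startswith("gpt") and "instruct" not in model) or provider == "together"

assert uses_chat_endpoint("openai", "gpt-4")
assert not uses_chat_endpoint("openai", "gpt-3.5-turbo-instruct")
assert uses_chat_endpoint("together", "mistralai/Mixtral-8x7B-Instruct-v0.1")
```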
@@ -270,49 +284,6 @@ def build_api_call_function(model):
 
             return output, length
 
-    elif model.startswith("together"):
-        TOGETHER_API_ENDPOINT = "https://api.together.xyz/inference"
-
-        provider, model = model.split("/", maxsplit=1)
-
-        @retry(
-            wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
-            stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
-            reraise=True,
-        )
-        async def api_call_function(prompt, generation_config):
-            headers = {
-                "Authorization": f"Bearer {TOGETHER_API_KEY}",
-                "User-Agent": "FM",
-            }
-
-            payload = {
-                "temperature": generation_config["temperature"]
-                if generation_config["do_sample"]
-                else 0,
-                "top_p": generation_config["top_p"]
-                if generation_config["do_sample"]
-                else 1,
-                "top_k": generation_config["top_k"]
-                if generation_config["do_sample"]
-                else 0,
-                "max_tokens": generation_config["max_new_tokens"],
-                "prompt": prompt,
-                "model": model,
-                "stop": generation_config["stop_sequences"],
-            }
-
-            LOGGER.info(f"{payload=}")
-
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    TOGETHER_API_ENDPOINT, json=payload, headers=headers
-                ) as response:
-                    output = (await response.json())["output"]["choices"][0]["text"]
-                    length = None
-
-            return output, length
-
     elif model.startswith("cohere"):
         _, model = model.split("/")
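With Together behind the OpenAI client, the hand-rolled `aiohttp` call against the legacy `/inference` endpoint is deleted wholesale; the shared `chat.completions` path above takes over. Roughly, the old payload maps onto the OpenAI-style call like this (a sketch reusing the `generation_config` keys from the deleted code):

```python
# Sketch: the deleted payload expressed through the OpenAI client.
response = await aclient.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": prompt}],
    temperature=generation_config["temperature"] if generation_config["do_sample"] else 0,
    top_p=generation_config["top_p"] if generation_config["do_sample"] else 1,
    max_tokens=generation_config["max_new_tokens"],
    stop=generation_config["stop_sequences"],
)
output = response.choices[0].message.content
```

One knob does not survive the migration: the standard chat-completions parameters include no `top_k`, so that setting from the old payload is effectively dropped.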