ApiCheck committed
Commit 935f280 · verified · 1 parent: 3608a72

Create app.py

Files changed (1): app.py (+408, -0)
app.py ADDED
@@ -0,0 +1,408 @@
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional, Union
import logging
import httpx
import uuid
import time
import json
from datetime import datetime, timezone
import requests
import uvicorn
import random

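# Logging setup, the FastAPI app instance, and wide-open CORS so any
# frontend can call the proxy.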
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MAX_RETRIES = 3

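# Request schemas. ChatRequest mirrors the OpenAI chat-completions payload;
# EmbeddingRequest is declared here but no embeddings endpoint is wired up below.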
class ChatRequest(BaseModel):
    messages: List[dict]
    model: str
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False
    tools: Optional[List[dict]] = []
    tool_choice: Optional[str] = "auto"


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str
    encoding_format: Optional[str] = "float"

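# Pull the bearer token out of the Authorization header; the token is
# expected to carry one or more upstream API keys, comma-separated.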
async def verify_authorization(authorization: str = Header(None)):
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    # removeprefix strips only the leading scheme; str.replace would also
    # mangle any "Bearer " occurring inside the token itself.
    token = authorization.removeprefix("Bearer ")
    return token

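# Fetch the model list from OpenAI with one key chosen at random from the pool.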
def get_openai_models(api_keys):
    api_key = random.choice(api_keys)
    try:
        client = openai.OpenAI(api_key=api_key)
        models = client.models.list()
        return models.model_dump()
    except Exception as e:
        logger.error(f"Error getting models from OpenAI with key {api_key[:8]}***: {e}")
        return {"error": str(e)}

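# Fetch the model list from the Gemini REST API and reshape it into
# OpenAI's /v1/models format.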
def get_gemini_models(api_keys):
    api_key = random.choice(api_keys)
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"

    try:
        # A timeout keeps a dead connection from hanging the request forever.
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_models_format(gemini_models)
        else:
            logger.error(
                f"Error getting models from Gemini with key {api_key[:8]}***: "
                f"{response.status_code} - {response.text}"
            )
            return {"error": f"Gemini API error: {response.status_code} - {response.text}"}

    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return {"error": f"Request failed: {e}"}

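# Map Gemini's "models/..." entries onto OpenAI's model-object schema.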
def convert_to_openai_models_format(gemini_models):
    openai_format = {"object": "list", "data": []}

    for model in gemini_models.get("models", []):
        openai_model = {
            "id": model["name"].split("/")[-1],
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),
            "owned_by": "google",
            "permission": [],
            "root": model["name"],
            "parent": None,
        }
        openai_format["data"].append(openai_model)

    return openai_format

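# Translate OpenAI-style chat messages (plain text and multimodal content
# parts) into Gemini "contents".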
def convert_messages_to_gemini_format(messages):
    gemini_messages = []
    for msg in messages:
        # Gemini only knows "user" and "model"; system/assistant both map to "model".
        role = "user" if msg["role"] == "user" else "model"
        parts = []
        if isinstance(msg["content"], str):
            parts.append({"text": msg["content"]})
        elif isinstance(msg["content"], list):
            for content in msg["content"]:
                if isinstance(content, str):
                    parts.append({"text": content})
                elif isinstance(content, dict) and content.get("type") == "text":
                    parts.append({"text": content["text"]})
                elif isinstance(content, dict) and content.get("type") == "image_url":
                    image_url = content["image_url"]["url"]
                    if image_url.startswith("data:image"):
                        # Base64 data URI; the MIME type is assumed to be JPEG.
                        parts.append(
                            {
                                "inline_data": {
                                    "mime_type": "image/jpeg",
                                    "data": image_url.split(",")[1],
                                }
                            }
                        )
                    else:
                        parts.append({"image_url": {"url": image_url}})
        gemini_messages.append({"role": role, "parts": parts})
    return gemini_messages

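# Wrap a raw Gemini response in OpenAI's chat.completion (or
# chat.completion.chunk) envelope.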
async def convert_gemini_response_to_openai(response, model, stream=False):
    if stream:
        chunk = response
        if not chunk["candidates"]:
            return None

        return {
            "id": "chatcmpl-" + str(uuid.uuid4()),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
                    },
                    "finish_reason": None,
                }
            ],
        }
    else:
        content = response["candidates"][0]["content"]["parts"][0]["text"]
        return {
            "id": "chatcmpl-" + str(uuid.uuid4()),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content,
                    },
                    "finish_reason": "stop",
                }
            ],
            # Gemini's REST response is not mapped onto OpenAI usage counts here.
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }

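# List models across every supplied key: keys with the "sk-" prefix are
# treated as OpenAI, everything else as Gemini.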
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    token = await verify_authorization(authorization)
    api_keys = [key.strip() for key in token.split(",")]

    all_models = []
    error_messages = []

    for api_key in api_keys:
        if api_key.startswith("sk-"):
            response = get_openai_models([api_key])
        else:
            response = get_gemini_models([api_key])

        if "error" in response:
            error_messages.append(response["error"])
        elif isinstance(response, dict) and "data" in response:
            all_models.extend(response["data"])
        else:
            logger.warning(
                f"Unexpected response format from model list API for key {api_key[:8]}***: {response}"
            )

    if error_messages and not all_models:
        raise HTTPException(
            status_code=500, detail=f"Errors encountered: {', '.join(error_messages)}"
        )

    return {"data": all_models, "object": "list"}

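# OpenAI-compatible chat endpoint that proxies to OpenAI or Gemini depending
# on the prefix of the randomly chosen key.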
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    token = await verify_authorization(authorization)
    api_keys = [key.strip() for key in token.split(",")]
    logger.info(f"Chat completion request - Model: {request.model}")

    retries = 0

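    # Retry loop: pick a key at random; on failure, drop it from the pool
    # and try another until MAX_RETRIES is reached or the pool is empty.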
    while retries < MAX_RETRIES:
        api_key = random.choice(api_keys)
        try:
            logger.info(f"Attempt {retries + 1} with API key: {api_key[:8]}***")

            if api_key.startswith("sk-"):
                client = openai.OpenAI(api_key=api_key)

                if request.stream:
                    logger.info("Streaming response enabled")

                    async def generate():
                        try:
                            stream_response = client.chat.completions.create(
                                model=request.model,
                                messages=request.messages,
                                temperature=request.temperature,
                                stream=True,
                            )

                            for chunk in stream_response:
                                chunk_json = chunk.model_dump_json()
                                yield f"data: {chunk_json}\n\n"
                            yield "data: [DONE]\n\n"
                        except Exception as e:
                            logger.error(f"Stream error: {str(e)}")
                            raise

                    return StreamingResponse(content=generate(), media_type="text/event-stream")

                else:
                    response = client.chat.completions.create(
                        model=request.model,
                        messages=request.messages,
                        temperature=request.temperature,
                    )
                    logger.info("Chat completion successful")
                    return response.model_dump()
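            # Gemini branch: translate the OpenAI-style request into a
            # generateContent payload and call the REST API directly.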
            else:
                gemini_messages = convert_messages_to_gemini_format(request.messages)
                payload = {
                    "contents": gemini_messages,
                    "generationConfig": {
                        "temperature": request.temperature,
                    },
                }

                if request.stream:
                    logger.info("Streaming response enabled")

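                    # Re-emit Gemini's SSE stream as OpenAI-style
                    # chat.completion.chunk events, rotating keys on 429s and errors.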
                    async def generate():
                        nonlocal api_key, retries, api_keys

                        while retries < MAX_RETRIES:
                            try:
                                async with httpx.AsyncClient() as client:
                                    stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
                                    async with client.stream("POST", stream_url, json=payload, timeout=60.0) as response:
                                        if response.status_code == 429:
                                            logger.warning(f"Rate limit reached for key: {api_key[:8]}***")
                                            retries += 1
                                            if retries >= MAX_RETRIES:
                                                yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                                break

                                            api_keys.remove(api_key)
                                            if not api_keys:
                                                yield f"data: {json.dumps({'error': 'All API keys exhausted'})}\n\n"
                                                break

                                            api_key = random.choice(api_keys)
                                            logger.info(f"Retrying with a new API key: {api_key[:8]}***")
                                            continue

                                        if response.status_code != 200:
                                            # A streamed httpx response must be read
                                            # before .text becomes available.
                                            await response.aread()
                                            logger.error(f"Error in streaming response with key {api_key[:8]}***: {response.status_code} - {response.text}")

                                            retries += 1
                                            if retries >= MAX_RETRIES:
                                                yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                                break

                                            api_keys.remove(api_key)
                                            if not api_keys:
                                                yield f"data: {json.dumps({'error': 'All API keys exhausted'})}\n\n"
                                                break

                                            api_key = random.choice(api_keys)
                                            logger.info(f"Retrying with a new API key: {api_key[:8]}***")
                                            continue

                                        async for line in response.aiter_lines():
                                            if line.startswith("data: "):
                                                try:
                                                    chunk = json.loads(line[6:])
                                                    if not chunk.get("candidates"):
                                                        continue

                                                    content = chunk["candidates"][0]["content"]["parts"][0]["text"]

                                                    new_chunk = {
                                                        "id": "chatcmpl-" + str(uuid.uuid4()),
                                                        "object": "chat.completion.chunk",
                                                        "created": int(time.time()),
                                                        "model": request.model,
                                                        "choices": [
                                                            {
                                                                "index": 0,
                                                                "delta": {"content": content},
                                                                "finish_reason": None,
                                                            }
                                                        ],
                                                    }
                                                    yield f"data: {json.dumps(new_chunk)}\n\n"

                                                except json.JSONDecodeError:
                                                    continue
                                        yield "data: [DONE]\n\n"
                                        return
                            except Exception as e:
                                logger.error(f"Stream error: {str(e)}")
                                retries += 1
                                if retries >= MAX_RETRIES:
                                    yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                    break

                                api_keys.remove(api_key)
                                if not api_keys:
                                    yield f"data: {json.dumps({'error': 'All API keys exhausted'})}\n\n"
                                    break

                                api_key = random.choice(api_keys)
                                logger.info(f"Retrying with a new API key: {api_key[:8]}***")
                                continue

                    return StreamingResponse(content=generate(), media_type="text/event-stream")
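                # Non-streaming Gemini call: a single POST to generateContent,
                # retried with a different key on failure.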
                else:
                    async with httpx.AsyncClient() as client:
                        non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
                        # Generation can take a while; match the streaming call's
                        # timeout rather than httpx's short default.
                        response = await client.post(non_stream_url, json=payload, timeout=60.0)

                        if response.status_code != 200:
                            logger.error(f"Error in non-streaming response with key {api_key[:8]}***: {response.status_code} - {response.text}")

                            retries += 1
                            if retries >= MAX_RETRIES:
                                raise HTTPException(status_code=500, detail="Max retries reached")

                            api_keys.remove(api_key)
                            if not api_keys:
                                raise HTTPException(status_code=500, detail="All API keys exhausted")

                            api_key = random.choice(api_keys)
                            logger.info(f"Retrying with a new API key: {api_key[:8]}***")
                            continue

                        gemini_response = response.json()
                        logger.info("Chat completion successful")
                        return await convert_gemini_response_to_openai(gemini_response, request.model)
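        # Any other failure (invalid key, network error, conversion bug):
        # log it, drop the key, and retry with one of the remaining keys.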
        except Exception as e:
            logger.error(f"Error in chat completion: {str(e)}")
            if isinstance(e, HTTPException):
                raise e

            retries += 1
            if retries >= MAX_RETRIES:
                logger.error("Max retries reached, giving up")
                raise HTTPException(status_code=500, detail="Max retries reached")

            api_keys.remove(api_key)
            if not api_keys:
                raise HTTPException(status_code=500, detail="All API keys exhausted")

            api_key = random.choice(api_keys)
            logger.info(f"Retrying with a new API key: {api_key[:8]}***")
            continue

    raise HTTPException(status_code=500, detail="Unexpected error in chat completion")

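# Health probes, served on both the root path and /health.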
@app.get("/health")
@app.get("/")
async def health_check():
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
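
For reference, a minimal client-side sketch of how the proxy above is meant to be called. This is illustrative only, not part of the commit: it assumes the server is running locally on port 8080, and the key values and model id are hypothetical placeholders.

from openai import OpenAI

# The proxy speaks the OpenAI wire format, so any OpenAI-compatible client
# works. The "API key" is the comma-separated pool of upstream keys; the
# server splits it, picks one key at random per request, and routes on the
# "sk-" prefix (OpenAI) versus anything else (Gemini). Mixing providers in
# one pool therefore only works if the model id is valid for whichever
# backend the chosen key belongs to.
client = OpenAI(
    base_url="http://localhost:8080/v1",
    api_key="gemini-key-1,gemini-key-2",  # hypothetical placeholder keys
)

response = client.chat.completions.create(
    model="gemini-1.5-flash",  # hypothetical model id
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)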