snailyp commited on
Commit
960a587
·
verified ·
1 Parent(s): 7c15c46

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +125 -122
main.py CHANGED
@@ -1,123 +1,126 @@
1
- from fastapi import FastAPI, HTTPException, Header, Request
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import StreamingResponse
4
- from pydantic import BaseModel
5
- import openai
6
- from typing import List, Optional
7
- import logging
8
- from itertools import cycle
9
- import asyncio
10
-
11
- import uvicorn
12
-
13
- from app import config
14
-
15
- # 配置日志
16
- logging.basicConfig(
17
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
18
- )
19
- logger = logging.getLogger(__name__)
20
-
21
- app = FastAPI()
22
-
23
- # 允许跨域
24
- app.add_middleware(
25
- CORSMiddleware,
26
- allow_origins=["*"],
27
- allow_credentials=True,
28
- allow_methods=["*"],
29
- allow_headers=["*"],
30
- )
31
-
32
- # API密钥配置
33
- API_KEYS = config.settings.API_KEYS
34
-
35
- # 创建一个循环迭代器
36
- key_cycle = cycle(API_KEYS)
37
- key_lock = asyncio.Lock()
38
-
39
-
40
- class ChatRequest(BaseModel):
41
- messages: List[dict]
42
- model: str = "llama-3.2-90b-text-preview"
43
- temperature: Optional[float] = 0.7
44
- max_tokens: Optional[int] = 8000
45
- stream: Optional[bool] = False
46
-
47
-
48
- async def verify_authorization(authorization: str = Header(None)):
49
- if not authorization:
50
- logger.error("Missing Authorization header")
51
- raise HTTPException(status_code=401, detail="Missing Authorization header")
52
- if not authorization.startswith("Bearer "):
53
- logger.error("Invalid Authorization header format")
54
- raise HTTPException(
55
- status_code=401, detail="Invalid Authorization header format"
56
- )
57
- token = authorization.replace("Bearer ", "")
58
- if token not in config.settings.ALLOWED_TOKENS:
59
- logger.error("Invalid token")
60
- raise HTTPException(status_code=401, detail="Invalid token")
61
- return token
62
-
63
-
64
- @app.get("/v1/models")
65
- async def list_models(authorization: str = Header(None)):
66
- await verify_authorization(authorization)
67
- async with key_lock:
68
- api_key = next(key_cycle)
69
- logger.info(f"Using API key: {api_key[:8]}...")
70
- try:
71
- client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
72
- response = client.models.list()
73
- logger.info("Successfully retrieved models list")
74
- return response
75
- except Exception as e:
76
- logger.error(f"Error listing models: {str(e)}")
77
- raise HTTPException(status_code=500, detail=str(e))
78
-
79
-
80
- @app.post("/v1/chat/completions")
81
- async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
82
- await verify_authorization(authorization)
83
- async with key_lock:
84
- api_key = next(key_cycle)
85
- logger.info(f"Using API key: {api_key[:8]}...")
86
-
87
- try:
88
- logger.info(f"Chat completion request - Model: {request.model}")
89
- client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
90
- response = client.chat.completions.create(
91
- model=request.model,
92
- messages=request.messages,
93
- temperature=request.temperature,
94
- max_tokens=request.max_tokens,
95
- stream=request.stream if hasattr(request, "stream") else False,
96
- )
97
-
98
- if hasattr(request, "stream") and request.stream:
99
- logger.info("Streaming response enabled")
100
-
101
- async def generate():
102
- for chunk in response:
103
- yield f"data: {chunk.model_dump_json()}\n\n"
104
-
105
- return StreamingResponse(content=generate(), media_type="text/event-stream")
106
-
107
- logger.info("Chat completion successful")
108
- return response
109
-
110
- except Exception as e:
111
- logger.error(f"Error in chat completion: {str(e)}")
112
- raise HTTPException(status_code=500, detail=str(e))
113
-
114
-
115
- @app.get("/health")
116
- async def health_check(authorization: str = Header(None)):
117
- await verify_authorization(authorization)
118
- logger.info("Health check endpoint called")
119
- return {"status": "healthy"}
120
-
121
-
122
- if __name__ == "__main__":
 
 
 
123
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, HTTPException, Header, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
4
+ from pydantic import BaseModel
5
+ import openai
6
+ from typing import List, Optional
7
+ import logging
8
+ from itertools import cycle
9
+ import asyncio
10
+
11
+ import uvicorn
12
+
13
+ from app import config
14
+
15
+ # 配置日志
16
+ logging.basicConfig(
17
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+
21
+ app = FastAPI()
22
+
23
+ # 允许跨域
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"],
27
+ allow_credentials=True,
28
+ allow_methods=["*"],
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ # API密钥配置
33
+ API_KEYS = config.settings.API_KEYS
34
+
35
+ # 创建一个循环迭代器
36
+ key_cycle = cycle(API_KEYS)
37
+ key_lock = asyncio.Lock()
38
+
39
+
40
+ class ChatRequest(BaseModel):
41
+ messages: List[dict]
42
+ model: str = "llama-3.2-90b-text-preview"
43
+ temperature: Optional[float] = 0.7
44
+ max_tokens: Optional[int] = 8000
45
+ stream: Optional[bool] = False
46
+
47
+
48
+ async def verify_authorization(authorization: str = Header(None)):
49
+ if not authorization:
50
+ logger.error("Missing Authorization header")
51
+ raise HTTPException(status_code=401, detail="Missing Authorization header")
52
+ if not authorization.startswith("Bearer "):
53
+ logger.error("Invalid Authorization header format")
54
+ raise HTTPException(
55
+ status_code=401, detail="Invalid Authorization header format"
56
+ )
57
+ token = authorization.replace("Bearer ", "")
58
+ if token not in config.settings.ALLOWED_TOKENS:
59
+ logger.error("Invalid token")
60
+ raise HTTPException(status_code=401, detail="Invalid token")
61
+ return token
62
+
63
+
64
+ @app.get("/v1/models")
65
+ @app.get("/hf/v1/models")
66
+ async def list_models(authorization: str = Header(None)):
67
+ await verify_authorization(authorization)
68
+ async with key_lock:
69
+ api_key = next(key_cycle)
70
+ logger.info(f"Using API key: {api_key[:8]}...")
71
+ try:
72
+ client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
73
+ response = client.models.list()
74
+ logger.info("Successfully retrieved models list")
75
+ return response
76
+ except Exception as e:
77
+ logger.error(f"Error listing models: {str(e)}")
78
+ raise HTTPException(status_code=500, detail=str(e))
79
+
80
+
81
+ @app.post("/v1/chat/completions")
82
+ @app.post("/hf/v1/chat/completions")
83
+ async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
84
+ await verify_authorization(authorization)
85
+ async with key_lock:
86
+ api_key = next(key_cycle)
87
+ logger.info(f"Using API key: {api_key[:8]}...")
88
+
89
+ try:
90
+ logger.info(f"Chat completion request - Model: {request.model}")
91
+ client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
92
+ response = client.chat.completions.create(
93
+ model=request.model,
94
+ messages=request.messages,
95
+ temperature=request.temperature,
96
+ max_tokens=request.max_tokens,
97
+ stream=request.stream if hasattr(request, "stream") else False,
98
+ )
99
+
100
+ if hasattr(request, "stream") and request.stream:
101
+ logger.info("Streaming response enabled")
102
+
103
+ async def generate():
104
+ for chunk in response:
105
+ yield f"data: {chunk.model_dump_json()}\n\n"
106
+
107
+ return StreamingResponse(content=generate(), media_type="text/event-stream")
108
+
109
+ logger.info("Chat completion successful")
110
+ return response
111
+
112
+ except Exception as e:
113
+ logger.error(f"Error in chat completion: {str(e)}")
114
+ raise HTTPException(status_code=500, detail=str(e))
115
+
116
+
117
+ @app.get("/health")
118
+ @app.get("/")
119
+ async def health_check():
120
+ await verify_authorization(authorization)
121
+ logger.info("Health check endpoint called")
122
+ return {"status": "healthy"}
123
+
124
+
125
+ if __name__ == "__main__":
126
  uvicorn.run(app, host="0.0.0.0", port=8000)