Deepinfra / main.py
API-Handler's picture
Upload 5 files
4e2263c verified
raw
history blame
1.86 kB
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Any, Dict
from deepinfra_handler import DeepInfraHandler
import json
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()
# Single module-level handler instance used by the request handlers in this file.
api_handler = DeepInfraHandler()
class Message(BaseModel):
    """One chat message: who said it (``role``) and what was said (``content``)."""

    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the ``/chat/completions`` endpoint.

    Numeric sampling fields are range-checked by pydantic via ``Field``
    constraints, so out-of-range values are rejected before the handler runs.
    """

    model: str
    messages: List[Message]
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=4096, ge=1)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    # default_factory builds a fresh empty list per instance instead of
    # relying on a shared mutable literal as the default.
    stop: Optional[List[str]] = Field(default_factory=list)
    stream: Optional[bool] = Field(default=False)
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Serve a chat completion as a single response or as an SSE stream.

    The validated request is flattened to keyword arguments and forwarded to
    ``api_handler.generate_completion``.

    Args:
        request: Validated request body.

    Returns:
        When ``request.stream`` is truthy, a ``StreamingResponse`` with media
        type ``text/event-stream`` that emits one ``data:`` event per chunk
        followed by ``data: [DONE]``. Otherwise, whatever
        ``generate_completion`` returns.

    Raises:
        HTTPException: Status 500 carrying the error text when the
            non-streaming call fails. An ``HTTPException`` raised by the
            handler itself is passed through unchanged.
    """
    # Function-scope import: fastapi is already a dependency of this module,
    # and this keeps the block self-contained.
    from fastapi.concurrency import run_in_threadpool

    # Convert request to dictionary
    params = request.dict()

    if request.stream:
        # Handle streaming response
        def generate():
            # The generator body only runs while the response is being sent,
            # i.e. after this endpoint has returned, so failures must be
            # handled here rather than by a try/except around the return.
            try:
                for chunk in api_handler.generate_completion(**params):
                    payload = {"choices": [{"delta": {"content": chunk}}]}
                    yield f"data: {json.dumps(payload)}\n\n"
            except Exception as exc:
                # Headers are already sent, so the status code cannot change;
                # report the failure in-band instead of dropping the stream.
                error_payload = {"error": {"message": str(exc)}}
                yield f"data: {json.dumps(error_payload)}\n\n"
            # Not placed in `finally`: yielding while the generator is being
            # closed (client disconnect) would raise RuntimeError.
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    # Handle regular response
    try:
        # generate_completion is called synchronously; presumably it performs
        # blocking network I/O (TODO confirm), so run it in the threadpool to
        # keep the event loop free for other requests.
        return await run_in_threadpool(api_handler.generate_completion, **params)
    except HTTPException:
        # Preserve a deliberate status code instead of rewriting it to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run as a
    # script, not when `app` is imported by another ASGI server.
    import uvicorn

    # Listens on every network interface, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)