Deepinfra / main.py
API-Handler's picture
Upload 5 files
4e2263c verified
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Any, Dict
from deepinfra_handler import DeepInfraHandler
import json
# ASGI application object; the route decorator and the uvicorn entry point
# further down in this module both refer to it.
app = FastAPI()
# One handler instance created at import time and shared by every request.
api_handler = DeepInfraHandler()
class Message(BaseModel):
    """One entry of a chat conversation.

    Both fields are required strings; no further validation is applied here.
    """

    # Author of the message (free-form string, not restricted to a fixed set).
    role: str
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body for the ``/chat/completions`` endpoint.

    ``model`` and ``messages`` are required. Every other field carries a
    default and, where bounds are given, is range-checked during validation.
    The fields are typed ``Optional[...]``, so an explicit ``null`` is also
    accepted for them.
    """

    # Required: target model identifier and the conversation so far.
    model: str
    messages: List[Message]

    # Sampling temperature, limited to [0.0, 2.0].
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    # Upper bound on generated tokens; must be at least 1.
    max_tokens: Optional[int] = Field(4096, ge=1)
    # Nucleus-sampling mass, limited to [0.0, 1.0].
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    # Penalties, each limited to [-2.0, 2.0].
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    # Stop sequences; empty by default. NOTE(review): Pydantic is documented
    # to copy mutable field defaults per instance, so the list literal is not
    # expected to be shared between requests.
    stop: Optional[List[str]] = Field([])
    # When true, the endpoint answers with a server-sent-event stream.
    stream: Optional[bool] = Field(False)
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Serve a chat completion as a single response or as an SSE stream.

    The validated request is flattened to a dict and forwarded as keyword
    arguments to ``api_handler.generate_completion``.

    Args:
        request: Validated request body.

    Returns:
        When ``request.stream`` is truthy, a ``StreamingResponse`` with media
        type ``text/event-stream``. Each chunk yielded by the handler is sent
        as ``data: {"choices": [{"delta": {"content": chunk}}]}`` and the
        stream ends with ``data: [DONE]``. If the handler fails while
        streaming, a single ``data: {"error": {"message": ...}}`` event is
        sent instead and the stream ends without ``[DONE]``.
        Otherwise, whatever ``generate_completion`` returns is returned as-is.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when the
            non-streaming path fails. An ``HTTPException`` raised by the
            handler itself is propagated unchanged.
    """
    try:
        # Convert request to dictionary
        params = request.dict()

        if request.stream:
            def event_stream():
                # This generator is consumed while the response is being
                # sent, i.e. after chat_completions has already returned, so
                # the outer try/except cannot observe failures raised here.
                # Report them in-band rather than silently cutting the stream.
                try:
                    for chunk in api_handler.generate_completion(**params):
                        payload = {'choices': [{'delta': {'content': chunk}}]}
                        yield f"data: {json.dumps(payload)}\n\n"
                except Exception as exc:
                    error_payload = {'error': {'message': str(exc)}}
                    yield f"data: {json.dumps(error_payload)}\n\n"
                    return
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
            )

        # Handle regular response.
        # NOTE(review): this is a synchronous call inside an async endpoint;
        # if generate_completion blocks on network I/O it will stall the
        # event loop for the duration of the call — TODO confirm and consider
        # running it in a worker thread.
        return api_handler.generate_completion(**params)
    except HTTPException:
        # Keep a deliberate HTTP error (status code and detail) intact
        # instead of rewrapping it as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every interface,
    # port 8000. Imported here so that importing this module does not
    # require uvicorn.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")