from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
from deepinfra_handler import DeepInfraHandler
import json

app = FastAPI()
api_handler = DeepInfraHandler()


class Message(BaseModel):
    """A single chat message in the conversation."""
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request parameters."""
    model: str
    messages: List[Message]
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=4096, ge=1)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    stop: Optional[List[str]] = Field(default_factory=list)
    stream: Optional[bool] = Field(default=False)
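

# Illustrative only (not part of the original file): a request body matching the
# schema above might look like the following; the model id is a placeholder.
#
# {
#     "model": "meta-llama/Meta-Llama-3-8B-Instruct",
#     "messages": [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Hello!"}
#     ],
#     "temperature": 0.7,
#     "stream": false
# }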


# Route path assumed for an OpenAI-compatible endpoint; the original snippet
# defines the handler without registering it on the app.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    try:
        # Convert the validated request into keyword arguments for the handler
        # (use request.model_dump() on Pydantic v2)
        params = request.dict()

        if request.stream:
            # Stream chunks back as Server-Sent Events in the OpenAI delta format
            def generate():
                for chunk in api_handler.generate_completion(**params):
                    yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream"
            )

        # Non-streaming: return the handler's response directly
        response = api_handler.generate_completion(**params)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
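

# Minimal client sketch (not part of the original file), assuming the server is
# running locally on port 8000 and the route above is mounted at
# /v1/chat/completions; the model id is a placeholder.
#
#   import json
#   import requests
#
#   payload = {
#       "model": "meta-llama/Meta-Llama-3-8B-Instruct",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "stream": True,
#   }
#   with requests.post("http://localhost:8000/v1/chat/completions",
#                      json=payload, stream=True) as resp:
#       for line in resp.iter_lines():
#           if line and line != b"data: [DONE]":
#               event = json.loads(line.removeprefix(b"data: "))
#               print(event["choices"][0]["delta"]["content"], end="")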