"""HTTP front-end exposing a ``/chat/completions`` endpoint.

Requests are validated with pydantic and forwarded to a module-level
``DeepInfraHandler`` instance, either as one response or as a
server-sent-event (SSE) stream.
"""

import json
import logging
from typing import Any, Dict, Iterator, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from deepinfra_handler import DeepInfraHandler

logger = logging.getLogger(__name__)

app = FastAPI()
api_handler = DeepInfraHandler()


class Message(BaseModel):
    """A single chat message: who said it (``role``) and what (``content``)."""

    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /chat/completions``.

    Numeric fields are range-checked by pydantic before the endpoint runs;
    out-of-range values are rejected during request validation.
    """

    model: str
    messages: List[Message]
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=4096, ge=1)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    # pydantic copies field defaults per instance, so this list literal is
    # not shared between requests.
    stop: Optional[List[str]] = Field(default=[])
    stream: Optional[bool] = Field(default=False)


def _request_params(request: ChatCompletionRequest) -> Dict[str, Any]:
    """Return the validated request as a plain dict of keyword arguments.

    Uses ``model_dump()`` when the installed pydantic provides it and falls
    back to ``dict()`` otherwise, so the module works on both major versions
    without deprecation warnings.
    """
    dump = getattr(request, "model_dump", None)
    if dump is None:
        dump = request.dict
    return dump()


def _sse_stream(params: Dict[str, Any]) -> Iterator[str]:
    """Yield SSE ``data:`` lines for each chunk produced by the handler.

    This generator is consumed only after the response has started, so a
    ``try/except`` around the code that *creates* the ``StreamingResponse``
    cannot see failures raised here. Errors are therefore handled inside the
    generator: they are logged and reported to the client as a final
    ``error`` event. ``[DONE]`` is sent only when the stream completes
    normally.
    """
    try:
        for chunk in api_handler.generate_completion(**params):
            yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
    except Exception as exc:  # boundary: the HTTP status line is already sent
        logger.exception("Streaming chat completion failed")
        yield f"data: {json.dumps({'error': {'message': str(exc)}})}\n\n"
        return
    yield "data: [DONE]\n\n"


@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Forward a chat-completion request to the handler.

    Args:
        request: The validated request body.

    Returns:
        A ``StreamingResponse`` of SSE events when ``request.stream`` is
        truthy; otherwise whatever ``api_handler.generate_completion``
        returns.

    Raises:
        HTTPException: With status 500 if the non-streaming handler call
            fails. An ``HTTPException`` raised by the handler itself is
            propagated unchanged.
    """
    params = _request_params(request)

    if request.stream:
        return StreamingResponse(
            _sse_stream(params),
            media_type="text/event-stream",
        )

    try:
        # generate_completion is a synchronous call; run it off the event
        # loop so that a slow upstream request (presumably network I/O --
        # the handler is defined elsewhere) does not stall other requests.
        return await run_in_threadpool(api_handler.generate_completion, **params)
    except HTTPException:
        # Do not rewrite deliberate HTTP errors as a generic 500.
        raise
    except Exception as exc:
        logger.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)