fastapi-mixtral-46.7b / main.py.bak
OjciecTadeusz's picture
Rename main.py to main.py.bak
2c0fb7d verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import List, Optional
# Application instance; the route handler below registers itself on it.
app = FastAPI()
# Shared Hugging Face inference client bound to the Mixtral instruct model;
# the /generate/ handler sends every formatted prompt through it.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
class ChatMessage(BaseModel):
    """A single turn of prior conversation supplied with a request."""

    # Speaker of the turn. format_prompt treats "user" as an instruction and
    # any other value as a model reply.
    role: str
    # Raw text of the turn.
    content: str
class GenerationRequest(BaseModel):
    """Request body accepted by the ``/generate/`` endpoint.

    The user text may be supplied as either ``prompt`` or ``message``; the
    handler prefers ``prompt`` and falls back to ``message``. Both are
    therefore optional at the schema level, and the handler rejects requests
    in which neither carries text.
    """

    # Primary user text. Optional so that a request carrying only `message`
    # passes validation and reaches the handler's fallback logic.
    prompt: Optional[str] = None
    # Alternative field for the user text, used when `prompt` is empty.
    message: Optional[str] = None
    # Optional instruction placed ahead of the conversation.
    system_message: Optional[str] = None
    # Optional earlier turns, oldest first.
    history: Optional[List[ChatMessage]] = None
    # Sampling temperature; the handler enforces a floor of 0.01.
    temperature: Optional[float] = 0.7
    # Nucleus-sampling threshold, forwarded to the client unchanged.
    top_p: Optional[float] = 0.95
def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
    """Assemble the instruction-formatted prompt string sent to the model.

    The result always starts with ``<s>`` and ends with the current message
    wrapped in ``[INST] ... [/INST]``.

    Args:
        message: The current user text; always appended last.
        history: Earlier turns, in order. A turn whose ``role`` is ``"user"``
            is wrapped in ``[INST] ... [/INST]``; any other turn is appended
            as `` <content></s>``. ``None`` or an empty list adds nothing.
        system_message: Optional text emitted first as
            ``[INST] <text> [/INST]</s>``. Skipped when empty or ``None``.

    Returns:
        The concatenated prompt string.
    """
    segments = ["<s>"]

    # The system message, when given, leads the prompt as its own closed turn.
    if system_message:
        segments.append(f"[INST] {system_message} [/INST]</s>")

    # Replay earlier turns; a missing or empty history contributes nothing.
    for turn in history or ():
        if turn.role == "user":
            segments.append(f"[INST] {turn.content} [/INST]")
        else:
            segments.append(f" {turn.content}</s>")

    # The current message is always the final instruction.
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
@app.post("/generate/")
async def generate_text(request: GenerationRequest):
    """Generate a completion for the request's text via the inference client.

    The user text is taken from ``request.prompt``, falling back to
    ``request.message``. It is combined with the optional history and system
    message by ``format_prompt`` and sent to ``client.text_generation``.

    Args:
        request: The parsed request body.

    Returns:
        A dict of the form ``{"response": <generated text>}``.

    Raises:
        HTTPException: With status 400 when neither ``prompt`` nor ``message``
            carries text; with status 500 when prompt formatting or the
            generation call fails.
    """
    # Validate before entering the try block: HTTPException is an ordinary
    # Exception subclass, so raising it inside would be caught by the generic
    # handler below and reported to the client as a 500 instead of a 400.
    message = request.prompt or request.message
    if not message:
        raise HTTPException(
            status_code=400,
            detail="Either 'prompt' or 'message' must be provided",
        )

    # The field is Optional, so an explicit null can arrive; fall back to the
    # schema default rather than letting max(None, 0.01) raise TypeError.
    temperature = 0.7 if request.temperature is None else request.temperature

    try:
        formatted_prompt = format_prompt(
            message=message,
            history=request.history,
            system_message=request.system_message,
        )
        params = {
            "temperature": max(temperature, 0.01),  # Ensure temperature isn't too low
            "max_new_tokens": 1048,
            "top_p": request.top_p,
            "repetition_penalty": 1.0,
            "do_sample": True,
            "seed": 42,
        }
        # NOTE(review): this call is not awaited, so it appears to run
        # synchronously inside the async handler — confirm whether it should
        # be offloaded to a worker thread.
        response = client.text_generation(formatted_prompt, **params)
    except Exception as e:
        # Surface any formatting or upstream failure as a 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The client returns the generated text directly as a string.
    return {"response": response}
# Script entry point: serve the app with uvicorn when run directly.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )