# Spaces:
# Sleeping
# Sleeping
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# FastAPI application object; the __main__ guard hands it to uvicorn.
app = FastAPI()

# Hub model id that every text_generation call in this module targets.
_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(_MODEL_ID)
class ChatMessage(BaseModel):
    """A single earlier turn of the conversation, as supplied in ``history``."""

    # Who produced the turn. format_prompt treats "user" as an instruction
    # and renders every other value as a model reply.
    role: str
    # Raw text of the turn, inserted verbatim into the prompt.
    content: str
class GenerationRequest(BaseModel):
    """JSON body for a text-generation request.

    The user's text goes in ``prompt`` or, alternatively, ``message``. When
    both are non-empty, ``prompt`` takes precedence.
    """

    # Optional so that a body carrying only ``message`` passes request
    # validation; with a required ``prompt`` the prompt-or-message fallback
    # in the handler could never be reached.
    prompt: Optional[str] = None
    # Alternative field name for the user's text.
    message: Optional[str] = None
    # Extra instruction emitted as its own turn at the start of the prompt.
    system_message: Optional[str] = None
    # Earlier turns, in the order they should appear in the prompt.
    history: Optional[List[ChatMessage]] = None
    # Sampling controls forwarded to the inference call (the handler floors
    # temperature at 0.01).
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
def format_prompt(
    message: str,
    history: Optional[List[ChatMessage]] = None,
    system_message: Optional[str] = None,
) -> str:
    """Build a Mixtral-Instruct style prompt from a chat exchange.

    Layout: ``<s>``, then an optional system turn, then each history turn in
    order, then the current message, using ``[INST] ... [/INST]`` markers.

    Args:
        message: The latest user text; always emitted as the final
            ``[INST] ... [/INST]`` turn.
        history: Earlier turns. A turn whose ``role`` is ``"user"`` becomes
            an ``[INST]`` turn; any other role is emitted as a reply that
            ends with ``</s>``. ``None`` or empty means no history.
        system_message: Optional instruction emitted as
            ``[INST] ... [/INST]</s>`` right after ``<s>``; skipped when
            ``None`` or empty.

    Returns:
        The prompt string, ending in ``[/INST]`` so the model continues with
        its reply.
    """
    parts = ["<s>"]

    if system_message:
        # NOTE(review): Mixtral-Instruct's documented template does not appear
        # to define a system turn; this emits it as a standalone instruction
        # with no reply. Kept unchanged to preserve model outputs — confirm.
        parts.append(f"[INST] {system_message} [/INST]</s>")

    for msg in history or ():
        if msg.role == "user":
            parts.append(f"[INST] {msg.content} [/INST]")
        else:
            # NOTE(review): every non-"user" role (e.g. "system") is rendered
            # as a model reply here — confirm that is intended.
            parts.append(f" {msg.content}</s>")

    parts.append(f"[INST] {message} [/INST]")
    # join instead of repeated += to avoid quadratic string building.
    return "".join(parts)
async def generate_text(request: GenerationRequest):
    """Generate a completion for ``request`` with the module-level ``client``.

    NOTE(review): no ``@app.<method>(...)`` route decorator is visible on this
    coroutine, so it is not registered on ``app`` as written — confirm
    whether a decorator was lost.

    Returns:
        ``{"response": <text returned by client.text_generation>}``.

    Raises:
        HTTPException: 400 if neither ``prompt`` nor ``message`` is non-empty;
            500 if prompt formatting or the inference call fails.
    """
    # Prefer ``prompt``; fall back to ``message``.
    message = request.prompt if request.prompt else request.message
    # Validated outside the try below so this 400 is not re-wrapped as a 500
    # by the catch-all handler.
    if not message:
        raise HTTPException(
            status_code=400,
            detail="Either 'prompt' or 'message' must be provided",
        )

    try:
        formatted_prompt = format_prompt(
            message=message,
            history=request.history,
            system_message=request.system_message,
        )

        # An explicit JSON null is accepted by the Optional field but would
        # make max() raise; pass None through so the server default applies.
        temperature = (
            max(request.temperature, 0.01)  # keep temperature strictly positive
            if request.temperature is not None
            else None
        )
        params = {
            "temperature": temperature,
            # NOTE(review): 1048 may be a typo for 1024 — confirm.
            "max_new_tokens": 1048,
            "top_p": request.top_p,
            "repetition_penalty": 1.0,
            "do_sample": True,
            # Fixed seed: identical requests should sample identically.
            "seed": 42,
        }

        # text_generation is a blocking call; run it in a worker thread so a
        # long generation does not stall the event loop for other requests.
        response = await run_in_threadpool(
            client.text_generation, formatted_prompt, **params
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The response is a plain string (see the original handling comment).
    return {"response": response}
if __name__ == "__main__":
    # Script entry point: serve ``app`` on all interfaces at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")