from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import uvicorn

|
# ASGI application instance served by uvicorn.
app = FastAPI()

# Permit browser clients on any origin to call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — restrict the origin list before production use.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Seq2seq checkpoint used for generation; fetched from the Hugging Face hub
# on first run and cached locally afterwards.
model_name = "bigscience/mt0-base"

# Tokenizer and model are module-level singletons shared by every request.
# The two loads are independent, so order does not matter.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
class GenerationRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text fed to the seq2seq model as its input.
    prompt: str
    # Cap on newly generated tokens (passed to generate() as max_new_tokens).
    max_tokens: int = 100
@app.post("/generate")
def generate(request: GenerationRequest):
    """Generate text from ``request.prompt`` with the loaded seq2seq model.

    Declared as a plain ``def`` (not ``async def``) deliberately: model
    inference is blocking, CPU/GPU-bound work, so FastAPI runs the handler
    in its threadpool instead of stalling the asyncio event loop — the
    original ``async def`` version froze all other requests for the whole
    duration of ``model.generate``.

    Returns:
        dict: ``{"generated_text": <decoded model output>}``.
    """
    # Tokenize; truncation keeps over-long prompts within the model's
    # maximum input length.
    inputs = tokenizer(
        request.prompt, return_tensors="pt", padding=True, truncation=True
    )

    # Move every input tensor to wherever the model weights live (CPU or GPU).
    device = model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Decode with the caller-supplied budget of new tokens.
    outputs = model.generate(**inputs, max_new_tokens=request.max_tokens)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}
@app.get("/")
def home():
    """Landing/health endpoint confirming the service is up."""
    return {"message": "Welcome to the Text Generation API"}


if __name__ == "__main__":
    # Entry point: the file imports uvicorn at the top but never used it,
    # leaving the module with no way to run directly. Serve the app when
    # executed as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)