import base64

from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from tempfile import NamedTemporaryFile
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()


class RequestData(BaseModel):
    prompt: str
    image: str
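
# The /query endpoint (below) expects JSON shaped like:
#   {"prompt": "<question about the image>",
#    "image": "<image bytes, base64-encoded>"}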


def load_model():
    # Load the model and tokenizer from the local "models" directory,
    # pinned to a fixed revision for reproducibility. trust_remote_code=True
    # is required because the model ships its own modeling code.
    model_id = "models"
    revision = "2024-08-26"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, revision=revision
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    return model, tokenizer
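
# NOTE (assumption): the pinned revision "2024-08-26" and the
# encode_image / answer_question methods used below match the API of
# vikhyatk/moondream2; if the "models" directory holds a different
# snapshot, adapt the inference calls accordingly.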


# Load once at import time so every request reuses the same model instance.
MODEL, TOKENIZER = load_model()
print("INFO: Model & Tokenizer loaded!")


@app.get("/")
def greet_json():
    return {"message": "Server is UP!"}


@app.post("/query")
def query(data: RequestData):
    prompt = data.prompt
    print(f"INFO: prompt - {prompt}")

    try:
        # Decode the base64 payload into raw image bytes.
        image_bytes = base64.b64decode(data.image)

        # Write the bytes to a temporary file so PIL can open them by path;
        # NamedTemporaryFile(delete=True) removes the file when the context
        # manager exits.
        with NamedTemporaryFile(delete=True, suffix=".png") as temp_image:
            temp_image.write(image_bytes)
            temp_image.flush()

            image = Image.open(temp_image.name)
            # encode_image / answer_question come from the model's custom
            # remote code (hence trust_remote_code=True above), not from the
            # base transformers API.
            enc_image = MODEL.encode_image(image)
            response = MODEL.answer_question(enc_image, str(prompt), TOKENIZER)

        return {"response": str(response)}
    except Exception as e:
        # Report any failure (invalid base64, unreadable image, model error)
        # as a 500 with the error message as the detail.
        raise HTTPException(status_code=500, detail=str(e))
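

# Example client call (sketch; assumes the server is reachable on
# localhost:8000 and that the `requests` package is installed):
#
#   import base64, requests
#
#   with open("example.png", "rb") as f:
#       payload = {
#           "prompt": "What is in this image?",
#           "image": base64.b64encode(f.read()).decode(),
#       }
#   print(requests.post("http://localhost:8000/query", json=payload).json())


if __name__ == "__main__":
    # Convenience entry point (sketch, assuming uvicorn is installed);
    # in deployment the app is typically started as `uvicorn <module>:app`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)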