"""FastAPI inference service for the allenai/Molmo-7B-D-0924 vision-language model.

Exposes:
    GET  /          -- liveness check
    POST /generate/ -- text generation conditioned on an image URL + text prompt
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import requests
import torch

# Define the FastAPI app
app = FastAPI()

# Load model and processor once at startup so every request reuses them.
processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-D-0924', trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    'allenai/Molmo-7B-D-0924', trust_remote_code=True
)

# Move the model to GPU if available; CPU works but is slow for a 7B model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class GenerateRequest(BaseModel):
    """Request body for /generate/: an image URL plus a text prompt."""
    image_url: str
    text_input: str


@app.get("/")
def root():
    """Liveness probe."""
    return {"message": "Molmo-7B-D API is up and running!"}


def _fetch_image(url: str) -> Image.Image:
    """Download *url* and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the host does not respond within 30 seconds.
    """
    # Timeout prevents a dead image host from hanging the worker forever.
    response = requests.get(url, stream=True, timeout=30)
    # Surface 404s etc. here instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # Normalize palette/RGBA/grayscale images to RGB for the processor.
    return Image.open(response.raw).convert("RGB")


@app.post("/generate/")
def generate_text(request: GenerateRequest):
    """Generate text for the image at ``request.image_url`` given ``request.text_input``.

    Returns ``{"generated_text": ...}``; any failure surfaces as HTTP 500
    with the exception message as detail (original contract preserved).
    """
    try:
        # Fetch and normalize the image from the URL.
        image = _fetch_image(request.image_url)

        # Preprocess inputs (tokens + pixel values) and move them to the model device.
        inputs = processor(
            images=[image], text=request.text_input, return_tensors="pt"
        ).to(device)

        # Pass ALL processor outputs (pixel values, attention mask, ...) to
        # generate, not just input_ids -- otherwise the image features never
        # reach the model and the output ignores the image entirely.
        # NOTE(review): the Molmo model card recommends processor.process() +
        # model.generate_from_batch(); confirm which entry point this
        # checkout's remote code expects.
        with torch.inference_mode():  # no autograd bookkeeping during generation
            output_ids = model.generate(**inputs, max_new_tokens=200)

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = processor.tokenizer.decode(
            output_ids[0][prompt_len:], skip_special_tokens=True
        )
        return {"generated_text": generated_text}
    except Exception as e:
        # Preserve original behavior: any failure becomes an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))