"""HTTP service that parses an uploaded image into labeled UI elements.

Models are loaded once at import time: a YOLO detector
(``weights/icon_detect/best.pt``) and a Florence-2 caption model
(``weights/icon_caption_florence``). The ``/process_image`` endpoint runs OCR,
detection and captioning through the project-local ``utils`` helpers and
returns the annotated image plus the parsed content and label coordinates.
"""

import base64
import io
import logging
import os
import tempfile
import threading
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
from ultralytics import YOLO

from utils import (
    check_ocr_box,
    get_caption_model_processor,
    get_som_labeled_img,
    get_yolo_model,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

YOLO_WEIGHTS_PATH = "weights/icon_detect/best.pt"
CAPTION_PROCESSOR_NAME = "microsoft/Florence-2-base"
CAPTION_WEIGHTS_PATH = "weights/icon_caption_florence"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the YOLO detector with the project helper, then place it on `device`.
yolo_model = get_yolo_model(model_path=YOLO_WEIGHTS_PATH)
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()


def _load_caption_model():
    """Load the caption model: fp16 on CUDA when possible, else fp32 on CPU.

    The CPU path deliberately uses float32: half precision is not a usable
    fallback on the CPU because many CPU kernels have no float16 implementation.
    """
    if device.type == "cuda":
        try:
            return AutoModelForCausalLM.from_pretrained(
                CAPTION_WEIGHTS_PATH,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to("cuda")
        except Exception as e:  # any GPU-side failure should fall back to CPU
            logger.warning(
                "Failed to load caption model on GPU: %s. Falling back to CPU.", e
            )
    else:
        logger.info("CUDA not available; loading caption model on CPU.")
    return AutoModelForCausalLM.from_pretrained(
        CAPTION_WEIGHTS_PATH,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )


# The processor is device-independent, so it is loaded outside the GPU
# fallback; a failure here must surface instead of leaving `processor` unbound.
processor = AutoProcessor.from_pretrained(
    CAPTION_PROCESSOR_NAME, trust_remote_code=True
)
model = _load_caption_model()
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")

# The models above are shared module state. Inference is run off the event
# loop (see `process_image`), and this lock keeps those calls serialized.
_inference_lock = threading.Lock()

app = FastAPI()


class ProcessResponse(BaseModel):
    """Response body of ``/process_image``."""

    image: str  # Base64-encoded PNG of the labeled image
    parsed_content_list: str  # Entries of parsed_content_list joined by "\n"
    label_coordinates: str  # str() of the coordinates from get_som_labeled_img


def _draw_bbox_config(image_width: int) -> dict:
    """Scale annotation text/line sizes relative to a 3200-px-wide image.

    Integer sizes are clamped to at least 1 so small images stay visible.
    """
    ratio = image_width / 3200
    return {
        "text_scale": 0.8 * ratio,
        "text_thickness": max(int(2 * ratio), 1),
        "text_padding": max(int(3 * ratio), 1),
        "thickness": max(int(3 * ratio), 1),
    }


def _to_png_base64(labeled_img) -> str:
    """Re-encode the labeled image as a base64 PNG string.

    Args:
        labeled_img: Either raw encoded image bytes or a base64 string, as
            returned by ``get_som_labeled_img``.

    Raises:
        ValueError: If ``labeled_img`` is neither ``bytes`` nor ``str``.
    """
    if isinstance(labeled_img, bytes):
        raw = labeled_img
    elif isinstance(labeled_img, str):
        raw = base64.b64decode(labeled_img)
    else:
        raise ValueError("dino_labled_img must be a base64-encoded string or bytes")

    # Decode and re-encode so the response is always a PNG, whatever format
    # the helper produced; the context manager closes the decoded image.
    buffered = io.BytesIO()
    with Image.open(io.BytesIO(raw)) as labeled:
        labeled.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR, element detection and captioning on ``image_input``.

    Args:
        image_input: Image to parse.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` holding the labeled image as a base64 PNG, the
        parsed content joined by newlines, and the label coordinates as a string.

    Raises:
        ValueError: If the labeled image is neither ``bytes`` nor ``str``.
        Exception: Anything raised by the helpers is logged and re-raised.
    """
    try:
        # The helpers take a file path. A per-call temporary directory (instead
        # of one fixed path) keeps concurrent requests/workers from overwriting
        # each other's input, and is cleaned up when the block exits.
        with tempfile.TemporaryDirectory(prefix="process_image_") as tmp_dir:
            image_save_path = os.path.join(tmp_dir, "saved_image_demo.png")
            image_input.save(image_save_path)

            # Perform OCR and get bounding boxes.
            ocr_bbox_rslt, _is_goal_filtered = check_ocr_box(
                image_save_path,
                display_img=False,
                output_bb_format="xyxy",
                goal_filtering=None,
                easyocr_args={"paragraph": False, "text_threshold": 0.9},
                use_paddleocr=True,
            )
            text, ocr_bbox = ocr_bbox_rslt

            # Get the labeled image, coordinates and parsed content.
            dino_labled_img, label_coordinates, parsed_content_list = (
                get_som_labeled_img(
                    image_save_path,
                    yolo_model,
                    BOX_TRESHOLD=box_threshold,
                    output_coord_in_ratio=True,
                    ocr_bbox=ocr_bbox,
                    draw_bbox_config=_draw_bbox_config(image_input.size[0]),
                    caption_model_processor=caption_model_processor,
                    ocr_text=text,
                    iou_threshold=iou_threshold,
                )
            )

        return ProcessResponse(
            image=_to_png_base64(dino_labled_img),
            parsed_content_list="\n".join(str(item) for item in parsed_content_list),
            label_coordinates=str(label_coordinates),
        )
    except Exception:
        logger.exception("Error in process function")
        raise


def _process_serialized(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Call ``process`` while holding the module-wide inference lock."""
    with _inference_lock:
        return process(image_input, box_threshold, iou_threshold)


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Parse an uploaded image and return the labeled result.

    Responds 400 when the upload cannot be decoded as an image, and 500 (with
    the error message as detail) for any failure during processing.
    """
    try:
        contents = await image_file.read()
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            logger.warning("Rejecting upload %s: %s", image_file.filename, e)
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a valid image"
            ) from e

        logger.info("Processing image: %s", image_file.filename)
        logger.info("Image size: %s", image_input.size)

        # Inference is CPU/GPU-bound and synchronous; run it in the threadpool
        # so the event loop keeps serving other requests meanwhile.
        response = await run_in_threadpool(
            _process_serialized, image_input, box_threshold, iou_threshold
        )

        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in process_image endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e