"""FastAPI service that labels UI screenshots.

An uploaded image is run through OCR (``check_ocr_box``) and then through
``get_som_labeled_img`` together with the YOLO model and the Florence-2
caption model. The endpoint returns the labeled image (base64 PNG), the parsed
content list and the label coordinates.
"""

import asyncio
import base64
import gc  # Garbage collector, used for explicit cleanup after each request
import io
import logging
import os
import tempfile

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCausalLM

from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Location of the fine-tuned caption model weights.
CAPTION_MODEL_PATH = "weights/icon_caption_florence"

# Directory that receives the per-request copy of the uploaded image.
IMAGE_SAVE_DIR = "imgs"

# Load YOLO model
yolo_model = get_yolo_model(model_path="weights/best.pt")

# Handle device placement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.to(device)


def _load_caption_model(target_device: torch.device):
    """Load the caption model onto *target_device* with a matching dtype.

    Half precision is only used on CUDA; on CPU the weights are loaded as
    float32 because float16 support in CPU kernels is limited.
    """
    dtype = torch.float16 if target_device.type == "cuda" else torch.float32
    return AutoModelForCausalLM.from_pretrained(
        CAPTION_MODEL_PATH,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(target_device)


# Load caption model and processor.
# The processor is loaded outside the try block: a processor failure has
# nothing to do with the GPU, and swallowing it would leave `processor`
# undefined and surface later as an unrelated NameError.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
try:
    model = _load_caption_model(device)
except Exception as e:
    if device.type != "cuda":
        # Already on CPU: retrying the identical load cannot help.
        raise
    logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    model = _load_caption_model(torch.device("cpu"))

caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!")

# Initialize FastAPI app
app = FastAPI()

MAX_QUEUE_SIZE = 10  # Set a reasonable limit based on your system capacity
request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)

# Strong references to long-running background tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task may be
# garbage-collected while it is still running.
_background_tasks = set()


# Define response model
class ProcessResponse(BaseModel):
    """Payload returned by the ``/process_image`` endpoint."""

    image: str  # Base64 encoded PNG of the labeled image
    parsed_content_list: str  # Parsed content items, one per line
    label_coordinates: str  # str() of the label coordinates structure


# Background worker to process queue tasks
async def worker():
    """Drain ``request_queue`` forever, observing one task at a time.

    The queued items are already-running ``asyncio.Task`` objects; the worker
    waits for each to finish, logs failures and marks the queue slot done.

    ``asyncio.wait`` is used instead of ``await task`` so that a cancelled
    task (``CancelledError`` is not an ``Exception``) cannot propagate into
    the worker and kill it, which would stop the bounded queue from ever
    draining again.
    """
    while True:
        task = await request_queue.get()
        try:
            await asyncio.wait([task])
            if task.cancelled():
                logger.warning("Queued task was cancelled before it finished.")
            elif task.exception() is not None:
                logger.error(f"Error while processing task: {task.exception()}")
        finally:
            request_queue.task_done()


# Start worker on startup
@app.on_event("startup")
async def startup_event():
    """Launch the background queue worker when the application starts."""
    logger.info("Starting background worker...")
    worker_task = asyncio.create_task(worker())
    # Keep a reference so the worker is not garbage-collected mid-execution.
    _background_tasks.add(worker_task)
    worker_task.add_done_callback(_background_tasks.discard)


def _build_draw_bbox_config(image_width: int) -> dict:
    """Scale the box-drawing parameters relative to a 3200px-wide image."""
    box_overlay_ratio = image_width / 3200
    return {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }


# Image processing function with memory cleanup
async def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR, detection and captioning on one image.

    Args:
        image_input: The image to analyse.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` holding the labeled image as base64 PNG, the
        parsed content joined by newlines and the stringified coordinates.

    Raises:
        HTTPException: With status 500 if any processing step fails.
    """
    image_save_path = None
    try:
        # Each request gets its own file. Requests run concurrently (the
        # heavy steps execute in worker threads), so a single shared path
        # would let one request overwrite another's input before it is read.
        os.makedirs(IMAGE_SAVE_DIR, exist_ok=True)
        fd, image_save_path = tempfile.mkstemp(
            prefix="saved_image_", suffix=".png", dir=IMAGE_SAVE_DIR
        )
        os.close(fd)

        # Save image (format is inferred from the .png suffix)
        image_input.save(image_save_path)
        logger.debug(f"Image saved to: {image_save_path}")

        # YOLO and caption model inference
        draw_bbox_config = _build_draw_bbox_config(image_input.size[0])

        # Blocking calls are pushed to a thread to keep the event loop free.
        ocr_bbox_rslt, _ = await asyncio.to_thread(
            check_ocr_box,
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt

        dino_labled_img, label_coordinates, parsed_content_list = await asyncio.to_thread(
            get_som_labeled_img,
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )

        # Convert labeled image to base64 (decoded and re-encoded as PNG)
        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # Join parsed content list
        parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

        return ProcessResponse(
            image=img_str,
            parsed_content_list=parsed_content_list_str,
            label_coordinates=str(label_coordinates),
        )
    except Exception as e:
        logger.error(f"Error in process function: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process the image: {e}"
        ) from e
    finally:
        # **Memory Cleanup** - runs on failure as well as on success.
        if image_save_path is not None:
            try:
                os.remove(image_save_path)
            except OSError as cleanup_error:
                logger.debug(f"Could not remove {image_save_path}: {cleanup_error}")
        torch.cuda.empty_cache()  # Free GPU memory
        gc.collect()  # Free CPU memory


# API endpoint for processing images
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Decode an uploaded image, process it and return the result.

    Args:
        image_file: The uploaded image.
        box_threshold: Forwarded to ``process``.
        iou_threshold: Forwarded to ``process``.

    Returns:
        The ``ProcessResponse`` produced by ``process``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 for any other failure.
    """
    try:
        # Read image file
        contents = await image_file.read()
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError:
            logger.error("Unsupported image format.")
            raise HTTPException(status_code=400, detail="Unsupported image format.")
        except OSError as e:
            # Raised by Pillow for truncated or corrupt files: a client
            # error, not a server fault.
            logger.error(f"Invalid image data: {e}")
            raise HTTPException(status_code=400, detail="Invalid or corrupted image data.")

        # Create processing task (it starts running as soon as it is created)
        task = asyncio.create_task(process(image_input, box_threshold, iou_threshold))

        # Add task to queue; this waits while the queue is full
        await request_queue.put(task)
        logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")

        # Wait for processing to complete
        return await task
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")