import diffusers import torch from fastapi import FastAPI, UploadFile, HTTPException, File from fastapi.responses import StreamingResponse from PIL import Image import io app = FastAPI() # Inicializa el pipeline al arrancar el servidor @app.on_event("startup") async def startup_event(): global pipe print("[DEBUG] Cargando modelo Marigold...") pipe = diffusers.MarigoldDepthPipeline.from_pretrained( "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 ).to("cuda") print("[DEBUG] Modelo Marigold cargado exitosamente.") @app.post("/predict-depth/") async def predict_depth(file: UploadFile = File(...)): try: # Verifica si el archivo es una imagen válida if not file.content_type.startswith("image/"): raise HTTPException(status_code=400, detail="El archivo subido no es una imagen.") # Carga la imagen desde el archivo subido image = Image.open(file.file).convert("RGB") # Realiza la predicción de profundidad print("[DEBUG] Realizando predicción de profundidad...") depth = pipe(image) # Exporta la profundidad como una imagen 16-bit PNG depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction) # Guarda la imagen generada en un buffer img_buffer = io.BytesIO() depth_16bit[0].save(img_buffer, format="PNG") img_buffer.seek(0) # Devuelve la imagen como respuesta return StreamingResponse(img_buffer, media_type="image/png") except Exception as e: print(f"[ERROR] {str(e)}") raise HTTPException(status_code=500, detail="Error procesando la imagen.") @app.get("/") async def root(): return {"message": "API de generación de mapas de profundidad con Marigold"}