# Hugging Face Space app script (Space status shown when captured: "Runtime error").
import os

import gradio as gr
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models
from torchvision import transforms
# Hugging Face Hub location of the trained checkpoint.
REPO_ID = "macapa/blindness_clas"
MODEL_FILENAME = "best_model_resnet18.pth"

# Direct-download URL of the checkpoint. It is derived from REPO_ID and
# MODEL_FILENAME so it always names the same file that is downloaded below.
model_url = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILENAME}"

# Download the checkpoint into the current directory. Use the path returned by
# hf_hub_download instead of assuming where the file ended up.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
    local_dir=".",
)
def _load_model(path):
    """Load the classifier checkpoint at *path* on CPU, ready for inference.

    Two checkpoint formats are accepted:
      * a whole pickled ``torch.nn.Module``, which is used as-is;
      * a ``state_dict``. It is assumed to belong to a torchvision ResNet-18;
        this is inferred only from the file name "best_model_resnet18.pth"
        (TODO confirm against the training code). The number of output classes
        is read from the ``fc.weight`` tensor so the head matches the checkpoint.

    Returns:
        The model, switched to evaluation mode.

    Raises:
        KeyError: If a state_dict has no ``fc.weight`` entry.
        RuntimeError: If the state_dict does not fit a ResNet-18
            (raised by ``load_state_dict``).
    """
    # SECURITY NOTE(review): torch.load unpickles the file, and unpickling can
    # execute arbitrary code. Only load checkpoints from a repository you trust.
    # NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True,
    # which loads a state_dict but rejects a whole pickled nn.Module.
    checkpoint = torch.load(path, map_location=torch.device("cpu"))

    if isinstance(checkpoint, torch.nn.Module):
        net = checkpoint
    else:
        num_classes = checkpoint["fc.weight"].shape[0]
        net = models.resnet18(num_classes=num_classes)
        net.load_state_dict(checkpoint)

    # Evaluation mode: BatchNorm uses its running statistics and Dropout is
    # disabled. Without this call, predictions are computed in training mode.
    net.eval()
    return net


# Load the PyTorch model once at start-up.
model = _load_model(model_path)
# Preprocessing applied to every uploaded image:
#   1. resize to exactly 256x256 pixels (a (h, w) size does not keep the
#      aspect ratio);
#   2. convert to a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
# NOTE(review): no Normalize step is applied. If the checkpoint was trained
# with normalization (e.g. ImageNet mean/std), confirm and add it here.
_INPUT_SIZE = (256, 256)
preprocess = transforms.Compose(
    [
        transforms.Resize(size=_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)
# Human-readable class names. A label's position in the list is the model
# output index it stands for: the classifier returns labels[argmax index].
# NOTE(review): these look like diabetic-retinopathy severity grades 0-4;
# confirm the order matches the training labels.
labels = [
    "No Blindness",   # index 0
    "Mild",           # index 1
    "Moderate",       # index 2
    "Severe",         # index 3
    "Proliferative",  # index 4
]
def classify_image(img):
    """Predict the severity label for a single retinal image.

    Args:
        img: The uploaded image. This is either a ``PIL.Image.Image`` or an
            array accepted by ``PIL.Image.fromarray`` (Gradio's default
            ``gr.Image`` output is a numpy array). It is ``None`` when the
            user submits without uploading anything.

    Returns:
        The entry of ``labels`` at the index of the highest model score.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    if not isinstance(img, Image.Image):
        # transforms.Resize in `preprocess` accepts PIL images or tensors,
        # not numpy arrays.
        img = Image.fromarray(img)
    # Force 3 channels. RGBA or grayscale uploads would otherwise become a
    # 4- or 1-channel tensor, which a 3-channel network (presumably the case
    # here) rejects.
    img = img.convert("RGB")

    batch = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        scores = model(batch)
    predicted = scores.argmax(dim=1)
    return labels[predicted.item()]
# Gradio UI: one image input, one label output.
interface = gr.Interface(
    fn=classify_image,
    # type="pil" makes Gradio pass a PIL image, which `preprocess`
    # (transforms.Resize) requires; Gradio's default is a numpy array.
    # The label text fixes mis-encoded characters ("Upload an image here").
    inputs=gr.Image(type="pil", label="Carga una imagen aquí"),
    outputs=gr.Label(num_top_classes=1),
    title="Blindness Classification",
    description="Classify the severity of blindness from retinal images.",
)
# Start the Gradio web server. share=True additionally requests a temporary
# public share link.
# NOTE(review): Gradio is understood to ignore share links on Hugging Face
# Spaces (with a warning); confirm for the installed Gradio version.
interface.launch(
    share=True,
)