# Source: Hugging Face Spaces app (page status at capture time: Running).
import gradio as gr | |
import json | |
import torch | |
from torch import nn | |
from torchvision import models, transforms | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
from fastapi import FastAPI | |
from gradio.routes import App | |
# Number of output units of the classifier head. The prediction code maps
# class 0 to fall armyworm and class 1 to healthy maize.
num_classes = 2

# In-memory storage for prediction results, keyed by the integer event ID
# that is returned to the caller alongside each prediction.
results_cache = {}
# Download model from Hugging Face
def download_model(repo_id="jays009/Restnet50", filename="pytorch_model.bin"):
    """Fetch the model checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds the checkpoint. Defaults to the
            repository this app was written for.
        filename: Name of the checkpoint file inside the repository.

    Returns:
        The value returned by ``hf_hub_download`` for the requested file
        (the path that is later passed to ``load_model``).
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 classifier and load the checkpoint at *model_path*.

    Args:
        model_path: Path to the saved state dict.

    Returns:
        The model with its weights loaded on CPU, switched to eval mode.
    """
    model = models.resnet50(pretrained=False)
    # Swap the final fully connected layer so the output size is num_classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file, and the file comes from a
    # remote repository. Consider passing weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Download the model and load it once at import time so every request
# reuses the same instance.
model_path = download_model()
model = load_model(model_path)

# Preprocessing applied to every input image before it reaches the model:
# resize, centre-crop to 224x224, convert to a tensor, then normalise each
# of the three channels with the mean/std below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Function to predict from image content
def predict_from_image(image):
    """Classify a PIL image as fall armyworm or healthy maize.

    Args:
        image: A ``PIL.Image.Image`` instance.

    Returns:
        ``{"result": <message>}`` on success, or ``{"error": <message>}``
        when the input is not a PIL image, the predicted class is not 0 or 1,
        or any step of the pipeline raises.
    """
    try:
        # Log the image processing
        print(f"Processing image: {image}")

        # Ensure the image is a PIL Image
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")

        # The normalisation step uses three per-channel statistics, so
        # RGBA, palette and grayscale inputs must be converted to RGB first.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)

        # Predict
        with torch.no_grad():
            outputs = model(image_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()

        # Interpret the result
        messages = {
            0: "The photo is of fall army worm with problem ID 126.",
            1: "The photo is of a healthy maize image.",
        }
        message = messages.get(predicted_class)
        if message is None:
            return {"error": "Unexpected class prediction."}
        return {"result": message}
    except Exception as e:
        # Boundary handler: the UI expects a JSON-serialisable error, not a
        # traceback.
        print(f"Error during image processing: {e}")
        return {"error": str(e)}
# Function to predict from URL
def predict_from_url(url, timeout=10):
    """Download an image from *url* and classify it.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the remote server before giving up.
            Without a timeout a stalled server would block the request
            indefinitely.

    Returns:
        The dictionary produced by ``predict_from_image``, or
        ``{"error": <message>}`` when the download or decoding fails.
    """
    try:
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting the allowed hosts if this is exposed publicly.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Ensure the request was successful
        image = Image.open(BytesIO(response.content))
        print(f"Fetched image from URL: {url}")
        return predict_from_image(image)
    except Exception as e:
        # Boundary handler: report network and decoding failures to the caller.
        print(f"Error during URL processing: {e}")
        return {"error": f"Failed to process the URL: {str(e)}"}
# Main prediction function with caching
def predict(image, url):
    """Run a prediction from an uploaded image or, failing that, a URL.

    The uploaded image takes precedence over the URL. The outcome is stored
    in ``results_cache`` under a generated event ID so that it can be looked
    up again later.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        url: Image URL, possibly empty.

    Returns:
        ``{"event_id": <int>, "result": <dict>}``, or ``{"error": <message>}``
        if an unexpected exception occurs.
    """
    try:
        # Pasted URLs often carry stray whitespace; a whitespace-only value
        # counts as "no URL".
        if isinstance(url, str):
            url = url.strip()

        if image is not None:
            result = predict_from_image(image)
        elif url:
            result = predict_from_url(url)
        else:
            result = {"error": "No input provided. Please upload an image or provide a URL."}

        # Generate and store the event ID. id() stays unique here because the
        # cache keeps a reference to every stored result.
        # NOTE(review): the cache is never pruned, so it grows with every
        # request.
        event_id = id(result)
        results_cache[event_id] = result

        # Log the result
        print(f"Event ID: {event_id}, Result: {result}")
        return {"event_id": event_id, "result": result}
    except Exception as e:
        print(f"Error in prediction function: {e}")
        return {"error": str(e)}
# Function to retrieve result by event_id
def get_result(event_id):
    """Look up a previously stored prediction by its event ID.

    Args:
        event_id: The event ID as an int or a string holding an integer.

    Returns:
        The cached result dictionary, or ``{"error": <message>}`` when the ID
        cannot be converted to an int or no result is stored under it.
    """
    try:
        # The UI passes the ID as text, while cache keys are ints.
        key = int(event_id)
    except (TypeError, ValueError, OverflowError) as e:
        return {"error": f"Invalid event ID: {str(e)}"}

    result = results_cache.get(key)
    if result is None:
        return {"error": "No result found for the provided event ID."}
    return result
# Create a FastAPI app for handling the GET request
app = FastAPI()


# NOTE(review): the original function was never registered on ``app``. The
# route path below is the reviewer's choice; confirm it against any clients.
# ``app`` also has to be served (or mounted) separately, because the Gradio
# launch call in this file does not serve it.
@app.get("/result/{event_id}")
def get_result_api(event_id: int):
    """Return the cached prediction for *event_id* via ``get_result``."""
    return get_result(event_id)
# Gradio interface setup
iface = gr.Blocks()

with iface:
    gr.Markdown("# Maize Anomaly Detection")

    # Prediction section: upload an image or give a URL.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")

    output = gr.JSON(label="Prediction Result")
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=predict,
        inputs=[image_input, url_input],
        outputs=output,
    )

    # Event ID retrieval section
    with gr.Row():
        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
        event_output = gr.JSON(label="Retrieved Result")

    retrieve_button = gr.Button("Get Result")
    retrieve_button.click(
        fn=get_result,
        inputs=[event_id_input],
        outputs=event_output,
    )

# Launch the Gradio interface
iface.launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)