# Author: jays009
# Commit: dd0dfd1 — "Update app.py" (verified)
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO
from fastapi import FastAPI
from gradio.routes import App
# Number of output classes of the classifier head; load_model() sizes the
# replacement fully connected layer with this value.
num_classes = 2

# In-memory store of prediction results keyed by event ID. It lives only as
# long as the process; nothing visible in this file evicts entries.
results_cache = dict()
def download_model(repo_id="jays009/Restnet50", filename="pytorch_model.bin"):
    """Download the classifier checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository hosting the checkpoint. Defaults to the
            repository this app was built against.
        filename: Name of the weights file inside ``repo_id``.

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load weights into it.

    Args:
        model_path: Path to a saved ``state_dict`` matching the architecture
            built here (ResNet-50 whose ``fc`` outputs ``num_classes`` logits).

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    # weights=None: random init, since the checkpoint supplies every parameter.
    # This replaces the deprecated ``pretrained=False`` argument.
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # The checkpoint is downloaded from a remote repository. weights_only=True
    # restricts unpickling to tensors and plain containers, so a tampered file
    # cannot execute arbitrary code while it is being loaded.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for batch-norm / dropout layers
    return model
# Fetch the checkpoint and build the inference model once, when the module is
# executed, so every prediction call reuses the same network object.
model_path = download_model()
model = load_model(model_path)

# Inference preprocessing, applied in order to each PIL image.
transform = transforms.Compose(
    [
        # Resize so the shorter edge is 256 px (aspect ratio preserved)...
        transforms.Resize(256),
        # ...then take the central 224x224 patch.
        transforms.CenterCrop(224),
        # PIL image -> float tensor (channels first, values in [0, 1]).
        transforms.ToTensor(),
        # Per-channel normalization with the standard ImageNet mean / std.
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
def predict_from_image(image):
    """Classify a single PIL image with the loaded model.

    Args:
        image: A ``PIL.Image.Image``. Images whose mode is not ``"RGB"`` (e.g.
            RGBA PNGs, grayscale or palette images) are converted to RGB first.
            ``transform`` normalizes exactly three channels, and ResNet-50's
            first convolution expects three input channels.

    Returns:
        ``{"result": <message>}`` for a recognised class, or
        ``{"error": <message>}`` when the input is invalid or inference fails.
        Exceptions are logged and returned as an error dict, never raised.
    """
    try:
        print(f"Processing image: {image}")
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")
        if image.mode != "RGB":
            image = image.convert("RGB")

        image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            outputs = model(image_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()

        if predicted_class == 0:
            return {"result": "The photo is of fall army worm with problem ID 126."}
        if predicted_class == 1:
            return {"result": "The photo is of a healthy maize image."}
        return {"error": "Unexpected class prediction."}
    except Exception as e:
        print(f"Error during image processing: {e}")
        return {"error": str(e)}
def predict_from_url(url, timeout=10):
    """Download an image from ``url`` and classify it with ``predict_from_image``.

    NOTE(review): ``url`` comes from a user-facing textbox and is fetched by the
    server. Any host the server can reach, including internal addresses, can be
    requested. Consider restricting allowed schemes and hosts.

    Args:
        url: HTTP(S) URL of the image. Surrounding whitespace is ignored.
        timeout: Seconds passed to ``requests.get``. Without a timeout, a
            stalled remote server would block this call indefinitely.

    Returns:
        Whatever ``predict_from_image`` returns, or
        ``{"error": "Failed to process the URL: ..."}`` if the download or
        image opening fails.
    """
    try:
        url = url.strip()
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # turn HTTP 4xx/5xx into an exception
        image = Image.open(BytesIO(response.content))
        print(f"Fetched image from URL: {url}")
        return predict_from_image(image)
    except Exception as e:
        print(f"Error during URL processing: {e}")
        return {"error": f"Failed to process the URL: {str(e)}"}
def predict(image, url):
    """Predict from an uploaded image or an image URL and cache the result.

    An uploaded image takes precedence. The URL is used only when no image was
    given and the URL is not blank.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None``.
        url: Text from the URL textbox; may be empty, blank or ``None``.

    Returns:
        ``{"event_id": <int>, "result": <dict>}``. ``result`` is what the chosen
        predictor returned (possibly an ``{"error": ...}`` dict) and is also
        stored in ``results_cache[event_id]``. If an unexpected exception occurs
        in this function, ``{"error": <message>}`` is returned instead.
    """
    import secrets

    try:
        if image is not None:
            result = predict_from_image(image)
        elif url and url.strip():
            result = predict_from_url(url)
        else:
            result = {"error": "No input provided. Please upload an image or provide a URL."}

        # Random ID instead of id(result): id() is a memory address, so IDs
        # would be predictable and let anyone enumerate other users' results.
        # 52 bits keeps the value exactly representable as a JavaScript number,
        # so the ID displayed by gr.JSON survives a copy/paste round trip.
        event_id = secrets.randbits(52)
        while event_id in results_cache:
            event_id = secrets.randbits(52)
        results_cache[event_id] = result

        print(f"Event ID: {event_id}, Result: {result}")
        return {"event_id": event_id, "result": result}
    except Exception as e:
        print(f"Error in prediction function: {e}")
        return {"error": str(e)}
def get_result(event_id):
    """Look up a cached prediction by its event ID.

    Args:
        event_id: The event ID as an ``int`` or as text convertible with ``int()``
            (the Gradio textbox delivers a string).

    Returns:
        The cached result dict, ``{"error": "No result found ..."}`` when nothing
        (or only a falsy value) is cached under that ID, or
        ``{"error": "Invalid event ID: ..."}`` when conversion or lookup fails.
    """
    try:
        cached = results_cache.get(int(event_id))
    except Exception as exc:
        return {"error": f"Invalid event ID: {str(exc)}"}
    return cached if cached else {"error": "No result found for the provided event ID."}
# FastAPI application exposing cached results over HTTP GET.
# NOTE(review): nothing visible in this file mounts or serves `app`. The script
# starts Gradio's own server via `iface.launch(...)`, so this route may be
# unreachable. Confirm how `app` is meant to be served (e.g. by mounting the
# Gradio Blocks into it).
app = FastAPI()


@app.get("/result/{event_id}")
def get_result_api(event_id: int):
    """Return the cached prediction for ``event_id`` (GET /result/{event_id}).

    FastAPI parses the path segment as an ``int``. The lookup and the error
    dicts come from ``get_result``.
    """
    payload = get_result(event_id)
    return payload
# Gradio UI: a prediction form (upload or URL) plus an event-ID lookup form.
with gr.Blocks() as iface:
    gr.Markdown("# Maize Anomaly Detection")

    # Prediction inputs, side by side: upload an image or paste a URL.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")
    output = gr.JSON(label="Prediction Result")
    submit_button = gr.Button("Submit")
    submit_button.click(fn=predict, inputs=[image_input, url_input], outputs=output)

    # Event ID retrieval section: fetch a previously cached result.
    with gr.Row():
        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
    event_output = gr.JSON(label="Retrieved Result")
    retrieve_button = gr.Button("Get Result")
    retrieve_button.click(fn=get_result, inputs=[event_id_input], outputs=event_output)
# Start the Gradio server only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not launch a server.
if __name__ == "__main__":
    iface.launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)