|
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms
|
|
|
|
|
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def load_model(checkpoint_path="best_model_epoch_43.pth", num_classes=4):
    """Build the ResNet-50 classifier and load its fine-tuned weights.

    The stock ``fc`` layer is swapped for a two-layer head
    (``in_features -> 256 -> num_classes``) before the checkpoint is applied.

    Args:
        checkpoint_path: Path of the saved ``state_dict`` to load.
        num_classes: Output width of the final linear layer.

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.

    Warns:
        UserWarning: If the checkpoint's keys do not line up with the model's.
    """
    # `weights=None` is the current spelling of the deprecated
    # `pretrained=False`: build the architecture without downloading weights.
    model = models.resnet50(weights=None)
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(model.fc.in_features, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, num_classes),
    )

    # NOTE(review): torch.load unpickles the file. If the checkpoint could ever
    # come from an untrusted source, pass `weights_only=True` — confirm the
    # installed torch version supports it.
    state_dict = torch.load(checkpoint_path, map_location=device)

    # strict=False tolerates key mismatches, but silently ignoring them could
    # leave layers randomly initialised; report them instead of hiding them.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} does not fully match the model: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}",
            stacklevel=2,
        )

    model = model.to(device)
    model.eval()
    return model
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, convert to a tensor, then normalise each channel.
# The mean/std values match the commonly used ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
# CSS hex colour used to render each predicted class label in the result HTML.
label_colors = {
    "Brown Spot": "#b2ff00",  # lime
    "Healthy": "#2ecc71",  # green
    "Leaf Blast": "#ff00d4",  # magenta
    "Neck Blast": "#ffd100",  # yellow
}
|
|
|
|
|
def get_confidence_color(confidence):
    """Map a confidence score to the CSS hex colour used to display it.

    Args:
        confidence: Probability of the predicted class, nominally in [0, 1].

    Returns:
        ``"#13ff00"`` for >= 0.75, ``"#00b9ff"`` for >= 0.50,
        ``"#f39c12"`` for >= 0.25, and ``"#e74c3c"`` otherwise.
    """
    # Checks run from the highest band down so that NaN, which fails every
    # comparison, ends in the lowest band instead of being shown as the
    # high-confidence colour.
    if confidence >= 0.75:
        return "#13ff00"
    if confidence >= 0.50:
        return "#00b9ff"
    if confidence >= 0.25:
        return "#f39c12"
    return "#e74c3c"
|
|
|
|
|
def predict(image):
    """Classify an uploaded image and return the result as an HTML snippet.

    Relies on the module-level ``model``, ``transform``, ``device``,
    ``label_colors`` and ``get_confidence_color``.

    Args:
        image: PIL image supplied by the Gradio input (configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        An HTML string with the predicted label and its confidence as a
        percentage, or a short prompt when ``image`` is ``None``.
    """
    # The form can be submitted empty; answer with a prompt rather than
    # failing on `None.convert`.
    if image is None:
        return (
            "<div style='color:#e74c3c; font-size:25px; font-weight:bold;'>"
            "Please upload an image.</div>"
        )

    # Index order must match the class indices used when the model was
    # trained — presumably alphabetical; verify against the training code.
    class_names = ["Brown Spot", "Healthy", "Leaf Blast", "Neck Blast"]

    # Force three channels (drops alpha / expands greyscale), then add the
    # batch dimension the model expects.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
        # Softmax is monotonic, so the arg-max of the probabilities is the
        # arg-max of the logits; one call yields both index and confidence.
        top_probability, predicted_class = torch.max(probabilities, dim=0)

    predicted_label = class_names[predicted_class.item()]
    confidence = top_probability.item()

    label_color = label_colors.get(predicted_label, "#FFFFFF")
    confidence_color = get_confidence_color(confidence)

    result = f"<div style='color:{label_color}; font-size:30px; font-weight:bold;'>{predicted_label}</div>"
    result += f"<div style='color:{confidence_color}; font-size:25px; font-weight:bold;'>Confidence: {confidence*100:.2f}%</div>"
    return result
|
|
|
|
|
def launch_interface():
    """Build the Gradio interface for the classifier.

    Despite the name, this only constructs the interface; the caller is
    responsible for calling ``.launch()`` on the returned object.

    Returns:
        A ``gr.Interface`` that sends an uploaded PIL image to ``predict``
        and renders the returned HTML.
    """
    citrus_theme = gr.themes.Citrus(
        primary_hue="emerald",
        neutral_hue="slate",
    )
    image_input = gr.Image(type="pil", label="Upload Rice Leaf Image")
    html_output = gr.HTML(label="Prediction Results")

    # Remote sample images offered as one-click examples (one per row).
    example_rows = [
        ["https://doa.gov.lk/wp-content/uploads/2020/06/brownspot3-1024x683.jpg"],
        ["https://arkansascrops.uada.edu/posts/crops/rice/images/Fig%206%20Rice%20leaf%20blast%20coalesced%20lesions.png"],
        ["https://th.bing.com/th/id/OIP._5ejX_5Z-M0cO5c2QUmPlwHaE7?w=280&h=187&c=7&r=0&o=5&dpr=1.1&pid=1.7"],
    ]

    # NOTE(review): newer Gradio releases name the flagging option
    # `flagging_mode` — confirm against the installed version before renaming.
    return gr.Interface(
        theme=citrus_theme,
        fn=predict,
        inputs=image_input,
        outputs=html_output,
        title="<span style='color: #00fff7; font-size:40px; font-weight: bold;'>Rice Disease Classification</span>",
        description="<span style='color: lightblue; font-size:26px;'>Upload a rice leaf image to detect its condition (Brown Spot, Healthy, Leaf Blast, or Neck Blast)</span>",
        examples=example_rows,
        allow_flagging="never",
    )
|
|
|
|
|
# Load the weights once at import time; `predict` reads this module-level
# `model` on every call.
model = load_model()


if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    interface = launch_interface()
    interface.launch(share=True)
|
|