File size: 3,654 Bytes
cbc5566
9dfc63c
 
 
 
 
bf44ad8
 
 
cbc5566
38d7439
 
cbc5566
38d7439
9dfc63c
eff8876
9dfc63c
 
38d7439
9dfc63c
38d7439
 
 
 
9dfc63c
 
38d7439
 
9dfc63c
 
38d7439
9dfc63c
 
 
 
38d7439
9dfc63c
 
bf44ad8
 
 
 
 
 
 
 
 
 
 
38d7439
bf44ad8
 
 
 
 
 
 
 
 
 
 
 
38d7439
2b3983d
38d7439
 
9dfc63c
2b3983d
38d7439
 
 
9dfc63c
 
 
 
 
 
 
fb31436
a62d15d
38d7439
 
 
 
bf44ad8
 
a62d15d
 
ce20917
61dc7b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO

# Size of the classifier head built in load_model(). predict() maps index 0
# to fall armyworm and index 1 to healthy, so change this together with that
# mapping if the checkpoint was trained on a different label set.
num_classes = 2

# Fetch the fine-tuned checkpoint from the Hugging Face Hub
def download_model():
    """Return the value hf_hub_download gives back for the checkpoint file.

    The file requested is ``pytorch_model.bin`` from the ``jays009/Restnet50``
    repository; the result is later passed to load_model() as the path to read.
    """
    return hf_hub_download(filename="pytorch_model.bin", repo_id="jays009/Restnet50")

# Build a ResNet-50 and fill it with the downloaded weights
def load_model(model_path):
    """Construct the classifier and load its weights from ``model_path``.

    Args:
        model_path: Location of the saved state dict to read.

    Returns:
        The ResNet-50 with a ``num_classes``-wide final layer, in eval mode.
    """
    # Map every tensor to CPU so the checkpoint loads on machines without a GPU.
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only load checkpoints from a trusted repository.
    cpu = torch.device("cpu")
    state_dict = torch.load(model_path, map_location=cpu)

    # No ImageNet weights are requested: the custom state dict replaces them.
    network = models.resnet50(pretrained=False)
    # Swap the stock head for one sized to this task before loading, so the
    # parameter shapes line up with the checkpoint.
    head_inputs = network.fc.in_features
    network.fc = nn.Linear(head_inputs, num_classes)

    network.load_state_dict(state_dict)
    network.eval()
    return network

# Runs at import time: fetch the checkpoint, then build the module-level
# `model` that predict() uses for every request.
model_path = download_model()
model = load_model(model_path)

# Preprocessing pipeline applied to each input image before inference
transform = transforms.Compose([
    # An int size scales the shorter edge to 256 px and keeps the aspect
    # ratio; it does not force a 256x256 square.
    transforms.Resize(256),
    # Take the central 224x224 patch
    transforms.CenterCrop(224),
    # PIL image -> tensor
    transforms.ToTensor(),
    # Per-channel normalization with the ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to convert image from URL to PIL image
def url_to_image(image_url, timeout=10):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        image_url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests can block indefinitely on an unresponsive host.

    Returns:
        The PIL image opened from the response body.

    Raises:
        requests.RequestException: On connection failures, timeouts, or a
            4xx/5xx HTTP status.
    """
    response = requests.get(image_url, timeout=timeout)
    # Surface the HTTP error itself rather than a confusing PIL decode error
    # when the server answers with an error page instead of an image.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))

# Function to convert base64 string to PIL image
def base64_to_pil(base64_string):
    """Decode a base64-encoded image into a PIL image.

    Accepts either a bare base64 payload or a data URI such as
    ``data:image/png;base64,<payload>``.

    Args:
        base64_string: The encoded image, with or without a data-URI header.

    Returns:
        The PIL image opened from the decoded bytes.
    """
    if base64_string.startswith("data:"):
        # Drop the "data:<mime>;base64," header. b64decode silently discards
        # characters outside its alphabet by default, so the header's letters
        # would otherwise be decoded as payload and corrupt the image bytes.
        _, _, base64_string = base64_string.partition(",")
    img_data = base64.b64decode(base64_string)
    return Image.open(BytesIO(img_data))

# Define the prediction function
def predict(image_input):
    """Classify an image and return a human-readable verdict.

    Args:
        image_input: A PIL image, or a string. Strings starting with ``http``
            are fetched with url_to_image(), strings starting with
            ``data:image`` are handed to base64_to_pil(), and any other
            string is opened as a local file path. ``None`` is tolerated.

    Returns:
        A message naming the predicted class, or a prompt to supply an image
        when ``image_input`` is ``None``.
    """
    # NOTE(review): with live=True the interface presumably calls this with
    # None when the image is cleared -- answer politely instead of crashing
    # inside the transform.
    if image_input is None:
        return "Please provide an image."

    if isinstance(image_input, str):
        if image_input.startswith("http"):  # Remote image
            image = url_to_image(image_input)
        elif image_input.startswith("data:image"):  # Base64 string
            image = base64_to_pil(image_input)
        else:  # Local image path
            image = Image.open(image_input)
    else:
        image = image_input  # Already a PIL image

    # The three-value Normalize step and the ResNet input layer both need
    # exactly three channels; RGBA, grayscale, and palette images would fail.
    image = image.convert("RGB")

    # Send the batch to wherever the model's weights live. Choosing CUDA
    # independently of the model raises a device-mismatch error on GPU hosts,
    # because load_model() leaves the model on the CPU.
    device = next(model.parameters()).device
    batch = transform(image).unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()

    # NOTE(review): the UI title says maize while the healthy message says
    # wheat -- confirm which crop is intended before changing either string.
    messages = {
        0: "The photo you've sent is of fall army worm with problem ID 126.",
        1: "The photo you've sent is of a healthy wheat image.",
    }
    return messages.get(predicted_class, "Unexpected class prediction.")

# Build the web UI: one image in, one line of text out
iface = gr.Interface(
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images.",
    fn=predict,  # Called with the uploaded image
    inputs=gr.Image(type="pil"),  # Hand predict() a PIL image
    outputs=gr.Textbox(),  # Shows the returned message
    live=True,  # Re-run whenever the input changes, no submit click needed
)

# Start serving; share=True additionally requests a public share link
iface.launch(share=True)