jays009 commited on
Commit
9dfc63c
·
verified ·
1 Parent(s): a62d15d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -10
app.py CHANGED
@@ -1,17 +1,65 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def process_image(image):
4
- # For now, this function will just return a placeholder response
5
- return "Placeholder output for uploaded image"
6
 
7
- # Define the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  iface = gr.Interface(
9
- fn=process_image, # Function to process the image
10
- inputs=gr.Image(type="pil", label="Upload Image"), # Accepts image input
11
- outputs=gr.Textbox(label="Prediction"), # Displays text output
12
- title="Image Input Example",
13
- description="Upload an image to get a prediction. Placeholder logic for now.",
 
14
  )
15
 
16
- # Launch the interface
17
  iface.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from torchvision import models, transforms
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image
7
 
8
+ # Define the number of classes
9
+ num_classes = 2 # Update with the actual number of classes in your dataset (e.g., 2 for healthy and anomalous)
 
10
 
11
+ # Download model from Hugging Face
12
+ def download_model():
13
+ model_path = hf_hub_download(repo_id="jays009/jays009/Restnet50", filename="pytorch_model.bin")
14
+ return model_path
15
+
16
+ # Load the model from Hugging Face
17
+ def load_model(model_path):
18
+ model = models.resnet50(pretrained=False) # Set pretrained=False because you're loading custom weights
19
+ model.fc = nn.Linear(model.fc.in_features, num_classes) # Adjust for the number of classes in your dataset
20
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model on CPU for compatibility
21
+ model.eval() # Set to evaluation mode
22
+ return model
23
+
24
+ # Download the model and load it
25
+ model_path = download_model() # Downloads the model from Hugging Face Hub
26
+ model = load_model(model_path)
27
+
28
+ # Define the transformation for the input image
29
+ transform = transforms.Compose([
30
+ transforms.Resize(256), # Resize the image to 256x256
31
+ transforms.CenterCrop(224), # Crop the image to 224x224
32
+ transforms.ToTensor(), # Convert the image to a Tensor
33
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
34
+ ])
35
+
36
+ # Define the prediction function
37
+ def predict(image):
38
+ # Apply the necessary transformations to the image
39
+ image = transform(image).unsqueeze(0) # Add batch dimension
40
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
41
+
42
+ with torch.no_grad():
43
+ outputs = model(image) # Perform forward pass
44
+ predicted_class = torch.argmax(outputs, dim=1).item() # Get the predicted class
45
+
46
+ # Create a response based on the predicted class
47
+ if predicted_class == 0:
48
+ return "The photo you've sent is of fall army worm with problem ID 126."
49
+ elif predicted_class == 1:
50
+ return "The photo you've sent is of a healthy wheat image."
51
+ else:
52
+ return "Unexpected class prediction."
53
+
54
+ # Create the Gradio interface
55
  iface = gr.Interface(
56
+ fn=predict, # Function for prediction
57
+ inputs=gr.Image(type="pil"), # Image input
58
+ outputs=gr.Textbox(), # Output: Predicted class
59
+ live=True, # Updates as the user uploads an image
60
+ title="Wheat Anomaly Detection",
61
+ description="Upload an image of wheat to detect anomalies like disease or pest infestation."
62
  )
63
 
64
+ # Launch the Gradio interface
65
  iface.launch()