jays009 commited on
Commit
9975291
·
verified ·
1 Parent(s): 6538291

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -28
app.py CHANGED
@@ -5,45 +5,40 @@ from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
 
8
- # Define the number of classes
9
- num_classes = 2 # Update with the actual number of classes in your dataset (e.g., 2 for healthy and anomalous)
10
 
11
- # Download model from Hugging Face
12
  def download_model():
13
- model_path = hf_hub_download(repo_id="jays009/jays009/Restnet50", filename="pytorch_model.bin")
14
  return model_path
15
 
16
- # Load the model from Hugging Face
17
  def load_model(model_path):
18
- model = models.resnet50(pretrained=False) # Set pretrained=False because you're loading custom weights
19
- model.fc = nn.Linear(model.fc.in_features, num_classes) # Adjust for the number of classes in your dataset
20
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model on CPU for compatibility
21
- model.eval() # Set to evaluation mode
22
  return model
23
 
24
- # Download the model and load it
25
- model_path = download_model() # Downloads the model from Hugging Face Hub
26
  model = load_model(model_path)
27
 
28
- # Define the transformation for the input image
29
  transform = transforms.Compose([
30
  transforms.Resize(256), # Resize the image to 256x256
31
  transforms.CenterCrop(224), # Crop the image to 224x224
32
  transforms.ToTensor(), # Convert the image to a Tensor
33
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
34
  ])
35
 
36
- # Define the prediction function
37
  def predict(image):
38
- # Apply the necessary transformations to the image
39
- image = transform(image).unsqueeze(0) # Add batch dimension
40
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
41
 
42
  with torch.no_grad():
43
- outputs = model(image) # Perform forward pass
44
- predicted_class = torch.argmax(outputs, dim=1).item() # Get the predicted class
 
45
 
46
- # Create a response based on the predicted class
47
  if predicted_class == 0:
48
  return "The photo you've sent is of fall army worm with problem ID 126."
49
  elif predicted_class == 1:
@@ -51,15 +46,13 @@ def predict(image):
51
  else:
52
  return "Unexpected class prediction."
53
 
54
- # Create the Gradio interface
55
  iface = gr.Interface(
56
- fn=predict, # Function for prediction
57
- inputs=gr.Image(type="pil"), # Image input
58
- outputs=gr.Textbox(), # Output: Predicted class
59
- live=True, # Updates as the user uploads an image
60
- title="Wheat Anomaly Detection",
61
- description="Upload an image of wheat to detect anomalies like disease or pest infestation."
62
  )
63
 
64
- # Launch the Gradio interface
65
  iface.launch()
 
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
 
8
+ num_classes = 2
 
9
 
 
10
  def download_model():
11
+ model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
12
  return model_path
13
 
 
14
  def load_model(model_path):
15
+ model = models.resnet50(pretrained=False)
16
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
17
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
18
+ model.eval()
19
  return model
20
 
21
+ model_path = download_model()
 
22
  model = load_model(model_path)
23
 
24
+
25
  transform = transforms.Compose([
26
  transforms.Resize(256), # Resize the image to 256x256
27
  transforms.CenterCrop(224), # Crop the image to 224x224
28
  transforms.ToTensor(), # Convert the image to a Tensor
29
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
30
  ])
31
 
 
32
  def predict(image):
33
+
34
+ image = transform(image).unsqueeze(0)
35
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
36
 
37
  with torch.no_grad():
38
+ outputs = model(image)
39
+ predicted_class = torch.argmax(outputs, dim=1).item()
40
+
41
 
 
42
  if predicted_class == 0:
43
  return "The photo you've sent is of fall army worm with problem ID 126."
44
  elif predicted_class == 1:
 
46
  else:
47
  return "Unexpected class prediction."
48
 
 
49
  iface = gr.Interface(
50
+ fn=predict,
51
+ inputs=gr.Image(type="pil"),
52
+ outputs=gr.Textbox(),
53
+ live=True,
54
+ title="Maize Anomaly Detection",
55
+ description="Upload an image of maize to detect anomalies like disease or pest infestation."
56
  )
57
 
 
58
  iface.launch()