jays009 commited on
Commit
2b3983d
·
verified ·
1 Parent(s): 73ad4c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -5,40 +5,43 @@ from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
 
8
- num_classes = 2
9
 
 
10
  def download_model():
11
  model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
12
  return model_path
13
 
 
14
  def load_model(model_path):
15
- model = models.resnet50(pretrained=False)
16
- model.fc = nn.Linear(model.fc.in_features, num_classes)
17
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
18
- model.eval()
19
  return model
20
 
21
- model_path = download_model()
 
22
  model = load_model(model_path)
23
 
24
-
25
  transform = transforms.Compose([
26
  transforms.Resize(256), # Resize the image to 256x256
27
  transforms.CenterCrop(224), # Crop the image to 224x224
28
  transforms.ToTensor(), # Convert the image to a Tensor
29
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
30
  ])
31
 
 
32
  def predict(image):
33
-
34
- image = transform(image).unsqueeze(0)
35
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
36
-
37
  with torch.no_grad():
38
- outputs = model(image)
39
- predicted_class = torch.argmax(outputs, dim=1).item()
40
 
41
-
42
  if predicted_class == 0:
43
  return "The photo you've sent is of fall army worm with problem ID 126."
44
  elif predicted_class == 1:
@@ -46,13 +49,16 @@ def predict(image):
46
  else:
47
  return "Unexpected class prediction."
48
 
 
49
  iface = gr.Interface(
50
- fn=predict,
51
- inputs=gr.Image(type="pil"),
52
- outputs=gr.Textbox(),
53
- live=True,
54
  title="Maize Anomaly Detection",
55
- description="Upload an image of maize to detect anomalies like disease or pest infestation."
 
56
  )
57
 
 
58
  iface.launch()
 
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
 
8
+ num_classes = 2 # Number of classes for your dataset
9
 
10
+ # Download model weights from Hugging Face
11
  def download_model():
12
  model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
13
  return model_path
14
 
15
+ # Load the model from the downloaded weights
16
  def load_model(model_path):
17
+ model = models.resnet50(pretrained=False) # Set pretrained=False for custom weights
18
+ model.fc = nn.Linear(model.fc.in_features, num_classes) # Adjust final layer for your number of classes
19
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model weights
20
+ model.eval() # Set model to evaluation mode
21
  return model
22
 
23
+ # Download and load the model
24
+ model_path = download_model()
25
  model = load_model(model_path)
26
 
27
+ # Image transformation pipeline
28
  transform = transforms.Compose([
29
  transforms.Resize(256), # Resize the image to 256x256
30
  transforms.CenterCrop(224), # Crop the image to 224x224
31
  transforms.ToTensor(), # Convert the image to a Tensor
32
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize for ImageNet
33
  ])
34
 
35
+ # Prediction function
36
  def predict(image):
37
+ image = transform(image).unsqueeze(0) # Add batch dimension
38
+ image = image.to(torch.device("cpu")) # Move the image to CPU (adjust if you want to use GPU)
39
+
 
40
  with torch.no_grad():
41
+ outputs = model(image) # Perform forward pass
42
+ predicted_class = torch.argmax(outputs, dim=1).item() # Get the predicted class ID
43
 
44
+ # Return appropriate response based on predicted class
45
  if predicted_class == 0:
46
  return "The photo you've sent is of fall army worm with problem ID 126."
47
  elif predicted_class == 1:
 
49
  else:
50
  return "Unexpected class prediction."
51
 
52
+ # Create the Gradio interface and expose it as an API
53
  iface = gr.Interface(
54
+ fn=predict, # Prediction function
55
+ inputs=gr.Image(type="pil"), # Image input (PIL format)
56
+ outputs=gr.Textbox(), # Text output (Predicted class description)
57
+ live=True, # Update predictions as the user uploads an image
58
  title="Maize Anomaly Detection",
59
+ description="Upload an image of maize to detect anomalies like disease or pest infestation.",
60
+ api=True # Expose the Gradio interface for API calls (POST requests)
61
  )
62
 
63
+ # Launch the Gradio interface
64
  iface.launch()