hb-setosys commited on
Commit
643a4b4
·
verified ·
1 Parent(s): e16a334

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -30
app.py CHANGED
@@ -1,53 +1,57 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
- import numpy as np
4
- from tensorflow.keras.applications.resnet152 import ResNet152, preprocess_input, decode_predictions
5
  from tensorflow.keras.preprocessing.image import img_to_array
6
  from PIL import Image
 
7
 
8
  # Load the pre-trained ResNet152 model
9
- MODEL_PATH = "resnet152-image-classifier.h5" # Path to your model file
10
  try:
11
  model = tf.keras.models.load_model(MODEL_PATH)
12
  except Exception as e:
13
  print(f"Error loading model: {e}")
14
  exit()
15
 
16
- def preprocess_image(image):
17
- """
18
- Preprocesses the uploaded image for prediction.
19
- """
20
- image = image.resize((224, 224)) # Resize to match the model input size (224x224 for ResNet152)
21
- image = img_to_array(image) # Convert PIL image to NumPy array
22
- image = preprocess_input(image) # Normalize for ResNet152
23
- image = np.expand_dims(image, axis=0) # Add batch dimension
24
  return image
25
-
26
- def predict(image):
27
  """
28
- Accepts an image, preprocesses it, and returns the top 3 predictions.
29
  """
30
- processed_image = preprocess_image(image)
31
- predictions = model.predict(processed_image)
32
- decoded_predictions = decode_predictions(predictions, top=3)[0]
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # Return predictions as a list of dictionaries
35
- results = [{"label": label, "confidence": float(confidence)} for _, label, confidence in decoded_predictions]
36
- return {"predictions": results}
37
 
38
- # Create a Gradio interface
39
  interface = gr.Interface(
40
- fn=predict, # The prediction function
41
- inputs=gr.Image(type="pil", label="Upload an Image"), # Input: Image
42
- outputs=gr.JSON(label="Predictions"), # Output: JSON (Top 3 Predictions)
43
  title="ResNet152 Image Classifier",
44
  description="Upload an image, and the model will predict what's in the image.",
45
- examples=[
46
- ["example_images/example1.jpg"], # Add paths to example images
47
- ["example_images/example2.jpg"]
48
- ],
49
  )
50
 
51
- # Launch the app
52
  if __name__ == "__main__":
53
- interface.launch(share=True)
 
1
  import gradio as gr
2
  import tensorflow as tf
3
+ from tensorflow.keras.applications.resnet import ResNet152, preprocess_input, decode_predictions
 
4
  from tensorflow.keras.preprocessing.image import img_to_array
5
  from PIL import Image
6
+ import numpy as np
7
 
8
  # Load the pre-trained ResNet152 model
9
+ MODEL_PATH = "resnet152-image-classifier.h5" # Path to the saved model
10
  try:
11
  model = tf.keras.models.load_model(MODEL_PATH)
12
  except Exception as e:
13
  print(f"Error loading model: {e}")
14
  exit()
15
 
16
+ def decode_image_from_base64(base64_str):
17
+ # Decode the base64 string to bytes
18
+ image_data = base64.b64decode(base64_str)
19
+ # Convert the bytes into a PIL image
20
+ image = Image.open(BytesIO(image_data))
 
 
 
21
  return image
22
+
23
+ def predict_image(image):
24
  """
25
+ Process the uploaded image and return the top 3 predictions.
26
  """
27
+ try:
28
+ # Preprocess the image
29
+ image = image.resize((224, 224)) # ResNet152 expects 224x224 input
30
+ image_array = img_to_array(image)
31
+ image_array = preprocess_input(image_array) # Normalize the image
32
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
33
+
34
+ # Get predictions
35
+ predictions = model.predict(image_array)
36
+ decoded_predictions = decode_predictions(predictions, top=3)[0]
37
+
38
+ # Format predictions as a list of tuples (label, confidence)
39
+ results = [(label, float(confidence)) for _, label, confidence in decoded_predictions]
40
+ return dict(results)
41
 
42
+ except Exception as e:
43
+ return {"Error": str(e)}
 
44
 
45
+ # Create the Gradio interface
46
  interface = gr.Interface(
47
+ fn=predict_image,
48
+ inputs=gr.Image(type="pil"), # Accepts an image input
49
+ outputs=gr.Label(num_top_classes=3), # Shows top 3 predictions with confidence
50
  title="ResNet152 Image Classifier",
51
  description="Upload an image, and the model will predict what's in the image.",
52
+ examples=["dog.jpg", "cat.jpg"], # Example images for users to test
 
 
 
53
  )
54
 
55
+ # Launch the Gradio app
56
  if __name__ == "__main__":
57
+ interface.launch()