hb-setosys commited on
Commit
c28dbdc
·
verified ·
1 Parent(s): 643a4b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -42
app.py CHANGED
@@ -1,57 +1,56 @@
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
- from tensorflow.keras.applications.resnet import ResNet152, preprocess_input, decode_predictions
4
- from tensorflow.keras.preprocessing.image import img_to_array
5
- from PIL import Image
6
  import numpy as np
 
 
 
 
 
7
 
8
- # Load the pre-trained ResNet152 model
9
- MODEL_PATH = "resnet152-image-classifier.h5" # Path to the saved model
10
- try:
11
- model = tf.keras.models.load_model(MODEL_PATH)
12
- except Exception as e:
13
- print(f"Error loading model: {e}")
14
- exit()
15
-
16
- def decode_image_from_base64(base64_str):
17
- # Decode the base64 string to bytes
18
- image_data = base64.b64decode(base64_str)
19
- # Convert the bytes into a PIL image
20
- image = Image.open(BytesIO(image_data))
21
  return image
22
-
23
- def predict_image(image):
 
24
  """
25
- Process the uploaded image and return the top 3 predictions.
26
  """
27
- try:
28
- # Preprocess the image
29
- image = image.resize((224, 224)) # ResNet152 expects 224x224 input
30
- image_array = img_to_array(image)
31
- image_array = preprocess_input(image_array) # Normalize the image
32
- image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
33
 
34
- # Get predictions
35
- predictions = model.predict(image_array)
36
- decoded_predictions = decode_predictions(predictions, top=3)[0]
37
 
38
- # Format predictions as a list of tuples (label, confidence)
39
- results = [(label, float(confidence)) for _, label, confidence in decoded_predictions]
40
- return dict(results)
41
 
42
- except Exception as e:
43
- return {"Error": str(e)}
44
 
45
- # Create the Gradio interface
46
  interface = gr.Interface(
47
- fn=predict_image,
48
- inputs=gr.Image(type="pil"), # Accepts an image input
49
- outputs=gr.Label(num_top_classes=3), # Shows top 3 predictions with confidence
50
- title="ResNet152 Image Classifier",
51
- description="Upload an image, and the model will predict what's in the image.",
52
- examples=["dog.jpg", "cat.jpg"], # Example images for users to test
 
 
 
53
  )
54
 
55
- # Launch the Gradio app
 
56
  if __name__ == "__main__":
57
- interface.launch()
 
1
+ #denis_mnist_cnn_model_resnet50_v1.h5") # Ensure you upload this file to Hugging Face Spaces
2
+
3
  import gradio as gr
4
  import tensorflow as tf
 
 
 
5
  import numpy as np
6
+ from tensorflow.keras.applications.resnet50 import preprocess_input
7
+ from tensorflow.keras.utils import load_img, img_to_array
8
+
9
+ # Load your trained model
10
+ model = tf.keras.models.load_model("denis_mnist_cnn_model_resnet50_v1.h5") # Ensure you upload this file to Hugging Face Spaces
11
 
12
+ # Define a function to preprocess the image
13
+ def preprocess_image(image):
14
+ """
15
+ Preprocesses the uploaded image for prediction.
16
+ """
17
+ image = image.resize((128, 128)) # Resize to match the model input size
18
+ image = img_to_array(image) # Convert PIL image to NumPy array
19
+ image = preprocess_input(image) # Normalize for ResNet50
20
+ image = np.expand_dims(image, axis=0) # Add batch dimension
 
 
 
 
21
  return image
22
+
23
+ # Define the prediction function
24
+ def predict(image):
25
  """
26
+ Accepts an image, preprocesses it, and returns the predicted label.
27
  """
28
+ processed_image = preprocess_image(image)
29
+ predictions = model.predict(processed_image)
30
+ predicted_class = np.argmax(predictions, axis=-1)[0] # Get the class index
31
+ confidence = np.max(predictions) # Get confidence score
32
+ #return f"Predicted Class: {predicted_class}, Confidence: {confidence:.2f}"
33
+ return {"prediction": int(predicted_class)}
34
 
 
 
 
35
 
36
+ # Create a Gradio interface
37
+ #interface = gr.Interface(fn=predict, inputs="image", outputs="json")
 
38
 
 
 
39
 
40
+ # Create a Gradio interface
41
  interface = gr.Interface(
42
+ fn=predict, # The prediction function
43
+ inputs=gr.Image(type="pil", label="Upload an Image"), # Input: Image
44
+ outputs=gr.Textbox(label="Prediction"), # Output: Textbox
45
+ title="MNIST ResNet50 Classifier",
46
+ description="Upload an image to classify it using the trained ResNet50 model.",
47
+ examples=[
48
+ ["example_images/example1.png"], # Add paths to example images in your Hugging Face repository
49
+ ["example_images/example2.png"]
50
+ ],
51
  )
52
 
53
+
54
+ # Launch the app
55
  if __name__ == "__main__":
56
+ interface.launch(share=True)