# Hugging Face Spaces page header (status: Sleeping)
import pandas as pd
import numpy as np
import tensorflow as tf
from keras.models import load_model
import gradio as gr
# Extended classes and labels
# predict_fn indexes this list with the position of the model's
# highest-scoring output, so the order presumably has to match the label
# order used when the model was trained (TODO confirm against training code).
classes = [
    'car', 'house', 'wine bottle', 'chair', 'table',
    'tree', 'camera', 'fish', 'rain', 'clock', 'hat',
    'dog', 'cat', 'bicycle', 'plane', 'book', 'computer',
]
# Reverse lookup: class name -> its index in `classes`.
labels = {name: index for index, name in enumerate(classes)}
# Number of categories; predict_fn reshapes the model output to this length.
num_classes = len(classes)
# Load the model
# The Keras model is read once at start-up from an HDF5 file; the path is
# relative, so it is resolved against the current working directory.
model = load_model('sketch_recognition_model_cnn.h5')
# Predict function for interface
def predict_fn(image):
    """Predict the class of a drawn image.

    Args:
        image: The input image drawn by the user. It is passed to
            ``tf.image.rgb_to_grayscale``, which needs a last dimension of
            size 3, so an RGB array is assumed (TODO confirm what the
            "paint" input delivers in the deployed Gradio version).

    Returns:
        The predicted class name, or a string starting with
        ``"Error in prediction: "`` when any step fails.
    """
    try:
        if image is None:
            # Nothing was submitted: report it plainly through the normal
            # error path instead of letting TensorFlow fail on a None input.
            raise ValueError("no image was provided")

        # Preprocessing: resize to 28x28, collapse RGB to one channel and
        # divide the pixel values by 255.
        resized_image = tf.image.resize(image, (28, 28))
        grayscale_image = tf.image.rgb_to_grayscale(resized_image)
        image_array = np.array(grayscale_image) / 255.0

        # Add the batch dimension: (28, 28, 1) -> (1, 28, 28, 1).
        image_array = image_array.reshape(1, 28, 28, 1)

        # Flatten the model output to a 1-D vector with one score per class.
        predictions = model.predict(image_array).reshape(num_classes)

        # The scores are already a NumPy array, so take the arg-max directly
        # and convert it to a plain int for list indexing.
        predicted_index = int(np.argmax(predictions))
        return classes[predicted_index]
    except Exception as e:
        # UI boundary: any failure is shown to the user as text rather than
        # propagating out of the callback.
        return f"Error in prediction: {e}"
# Gradio application interface
# A drawing canvas ("paint") feeds predict_fn and the returned string is shown
# in a "label" output. The interface is bound to a name so the object stays
# reachable after construction; it is still launched at module level.
# NOTE(review): `interpretation` is believed to be accepted only by older
# Gradio releases -- confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=predict_fn,
    inputs="paint",
    outputs="label",
    title="DoodleDecoder",
    description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat, Dog, Cat, Bicycle, Plane, Book, Computer",
    interpretation='default',
    article="Draw large with thick stroke.",
)
demo.launch()