import logging

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.models import load_model
# Names of the categories the sketch classifier can output.
# The position of a name in this list is its integer label (see ``labels``),
# so the order must not be changed casually.
# NOTE(review): presumably this order matches the one used when the model was
# trained -- confirm before reordering or adding entries.
classes = [
    'car', 'house', 'wine bottle', 'chair', 'table',
    'tree', 'camera', 'fish', 'rain', 'clock', 'hat',
    'dog', 'cat', 'bicycle', 'plane', 'book', 'computer',
]

# Reverse lookup: class name -> integer index into ``classes``.
labels = {name: index for index, name in enumerate(classes)}

# Total number of categories; used to flatten the model output to one score
# per class.
num_classes = len(classes)
# ``load_model`` is also imported at the top of the file; the import is
# repeated here so this statement resolves regardless of what precedes it.
from keras.models import load_model

# Load the trained network once at import time so every prediction request
# reuses the same in-memory model. The path is relative to the current working
# directory.
# NOTE(review): the filename is spelled "recogination" -- presumably this
# matches the saved file on disk; confirm before renaming.
model = load_model('sketch_recogination_model_cnn.h5')
def predict_fn(image) -> str:
    """Predict the class of a drawn image.

    The drawing is resized to 28x28, converted to grayscale, scaled by
    1/255, reshaped to a single-sample batch of shape (1, 28, 28, 1) and
    passed to the module-level ``model``. The name in ``classes`` whose
    score is highest is returned.

    Args:
        image: The input image drawn by the user. It is handed to
            ``tf.image.rgb_to_grayscale``, so it needs three colour
            channels in its last dimension. ``None`` (nothing was drawn)
            is accepted and reported as an error message.

    Returns:
        The predicted class name, or a string starting with
        ``"Error in prediction: "`` when no image was supplied or when
        preprocessing/inference fails. Ordinary exceptions are not
        propagated to the caller.
    """
    # An empty canvas gives no image at all; answer with a readable message
    # instead of letting the tensor conversion fail with an opaque error.
    if image is None:
        return "Error in prediction: no image provided"

    try:
        # Bring the drawing to the 28x28 single-channel layout fed to the model.
        resized_image = tf.image.resize(image, (28, 28))
        grayscale_image = tf.image.rgb_to_grayscale(resized_image)
        image_array = np.array(grayscale_image) / 255.0

        # Add the batch dimension: one sample of 28x28x1.
        image_array = image_array.reshape(1, 28, 28, 1)

        # verbose=0 keeps Keras from printing a progress bar for every request.
        predictions = model.predict(image_array, verbose=0).reshape(num_classes)

        # ``predictions`` is already a NumPy array, so pick the best class
        # with NumPy directly and convert to a plain int for list indexing.
        predicted_index = int(np.argmax(predictions))
        return classes[predicted_index]
    except Exception as e:
        # This is the UI boundary: show the failure to the user as text, but
        # keep the full traceback in the server log so it is not lost.
        logging.getLogger(__name__).exception("Sketch prediction failed")
        return f"Error in prediction: {e}"
# Build the web UI: a drawing canvas as input and a text label as output,
# with ``predict_fn`` called on each submission.
# The category list shown to the user is generated from ``classes`` so the
# description cannot drift from what the model can actually predict; with the
# current list it yields "Car, House, Wine bottle, ..., Computer".
demo = gr.Interface(
    fn=predict_fn,
    inputs="paint",
    outputs="label",
    title="DoodleDecoder",
    description="Draw something from: "
    + ", ".join(name.capitalize() for name in classes),
    article="Draw large with thick stroke.",
)

# Start the Gradio server as soon as the script is executed.
demo.launch()
|