# MNIST digit-recognition Streamlit app.
# Author: Kyle Dampier
# Commit 1a84122: "changed drawing size and emoji preview"
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st
# st.set_page_config(layout="wide")  # if enabled, must precede every other st.* call
st.write('# MNIST Digit Recognition')
st.write('## Using a CNN `Keras` model')


@st.cache_resource
def _load_model(path):
    """Load the pre-trained Keras model once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction
    (including every canvas stroke when realtime updates are on). Without
    caching, the network would be re-read from disk and rebuilt each time.

    Args:
        path: Filesystem path of the saved Keras model (e.g. an ``.h5`` file).

    Returns:
        The loaded ``tf.keras`` model, shared by all sessions of this server.
    """
    return tf.keras.models.load_model(path)


# Import Pre-trained Model (cached across script reruns)
model = _load_model('mnist.h5')
# Global matplotlib font size for the figures drawn by this app.
plt.rcParams.update({'font.size': 18})
# --- Sidebar settings -------------------------------------------------------
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# --- Drawing canvas ---------------------------------------------------------
# Square canvas, 9x the model's 28x28 input size. Strokes are white on a black
# background, i.e. the same polarity as MNIST digits (light digit, dark field).
canvas_side = 28 * 9
canvas_result = st_canvas(
    key="canvas",
    drawing_mode='freedraw',
    width=canvas_side,
    height=canvas_side,
    background_color='#000000',
    stroke_color='#FFFFFF',
    stroke_width=stroke_width,
    fill_color="rgba(255, 165, 0, 0.3)",  # constant semi-transparent fill
    # Re-run the script while drawing only when the sidebar box is ticked.
    update_streamlit=realtime_update,
)
# Once the canvas has produced pixel data, downsample it to the model's input
# size, run the classifier and chart the per-digit probabilities.
if canvas_result.image_data is not None:
    st.write('### Resized Image')
    st.write("The image needs to be resized, because it can only input 28x28 images")

    # The canvas returns an (H, W, 4) RGBA array. Pillow infers RGBA from a
    # 4-channel uint8 array, so no explicit (now-deprecated) `mode=` is needed.
    # Collapse to one grayscale channel, then shrink to 28x28.
    im = ImageOps.grayscale(
        Image.fromarray(canvas_result.image_data.astype('uint8'))
    ).resize((28, 28))
    # Show the downsampled image scaled back up to the canvas width.
    st.image(im, width=28 * 9)

    # Scale pixel values to [0, 1] and add batch and channel axes so the
    # array has shape (1, 28, 28, 1).
    data = (img_to_array(im) / 255).reshape(1, 28, 28, 1).astype('float32')

    st.write('### Predicted Digit')
    # verbose=0: don't log a Keras progress bar on every rerun/stroke.
    prediction = model.predict(data, verbose=0)

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), prediction[0])
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    st.pyplot(fig)
    # Streamlit reruns this script for every interaction. Close the figure so
    # pyplot does not keep one more open figure alive per rerun.
    plt.close(fig)