# Kyle Dampier
# changed order of the layout and added more descriptions
# 60d926f
# raw
# history blame
# 1.87 kB
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st
# st.set_page_config(layout="wide")

# Page heading, followed by a pointer to the notebook the model was trained in.
# Each entry is written with its own st.write call, in order.
for intro_markdown in (
    '# MNIST Digit Recognition',
    '## Using trained CNN `Keras` model',
    'To view how this model was trained go to the `Files and Versions` tab and download the `Week1.ipynb` notebook',
):
    st.write(intro_markdown)

# Load the pre-trained model from the relative path 'mnist.h5'.
model = tf.keras.models.load_model('mnist.h5')

# Raise matplotlib's default font size for the figures drawn later on.
plt.rcParams['font.size'] = 18
# Sidebar settings: brush thickness (1-25, default 9) and whether the canvas
# reports changes back to the app in realtime (on by default).
stroke_width = st.sidebar.slider(
    "Stroke width: ", min_value=1, max_value=25, value=9)
realtime_update = st.sidebar.checkbox("Update in realtime", value=True)

# The drawing area is a square whose side is 9x the 28-pixel image side.
CANVAS_SIDE = 28 * 9

# White free-hand strokes on a black background.
canvas_result = st_canvas(
    key="canvas",
    drawing_mode='freedraw',
    width=CANVAS_SIDE,
    height=CANVAS_SIDE,
    background_color='#000000',
    stroke_color='#FFFFFF',
    stroke_width=stroke_width,
    fill_color="rgba(255, 165, 0, 0.3)",  # fixed fill color with some opacity
    update_streamlit=realtime_update,
)
# Classify the drawing only once the canvas component has returned pixel data.
if canvas_result.image_data is not None:
    # Interpret the canvas array as an RGBA image, convert it to grayscale
    # and shrink it to 28x28, the size that is fed to the model below.
    rgba_pixels = canvas_result.image_data.astype('uint8')
    im = ImageOps.grayscale(
        Image.fromarray(rgba_pixels, mode="RGBA")).resize((28, 28))

    # Scale the 8-bit intensities by 1/255 and reshape to a single-sample
    # batch of shape (1, 28, 28, 1) in float32.
    data = img_to_array(im)
    data = data / 255
    data = data.reshape(1, 28, 28, 1)
    data = data.astype('float32')

    st.write('### Predicted Digit')
    prediction = model.predict(data)

    # Plot the first (only) prediction row as one bar per digit 0-9. An
    # explicit Figure/Axes pair is used instead of pyplot's implicit
    # "current figure" so the figure can be released reliably afterwards.
    result, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), prediction[0])
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    st.write(result)
    # pyplot keeps every figure it creates registered until it is closed, and
    # Streamlit re-executes this script on each interaction. Close the figure
    # once it has been rendered so figures do not accumulate across reruns.
    plt.close(result)

    st.write('### Resized Image')
    st.write("The image needs to be resized, because it can only input 28x28 images")
    # Show the 28x28 model input enlarged to the same width as the canvas.
    st.image(im, width=28*9)