File size: 2,089 Bytes
5db5524
 
 
e3171cd
5db5524
 
 
 
 
 
 
60d926f
 
5db5524
 
 
a11556c
5db5524
 
 
 
 
 
 
 
 
 
 
 
 
 
1a84122
 
5db5524
 
 
 
 
60d926f
f213495
5db5524
 
 
f213495
5db5524
 
 
 
 
f213495
5db5524
 
 
f213495
5db5524
 
 
 
 
 
 
 
60d926f
f213495
 
a11556c
 
f213495
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array

from streamlit_drawable_canvas import st_canvas
import streamlit as st

# st.set_page_config(layout="wide")

# Page header; st.write renders these strings as Markdown.
st.write('# MNIST Digit Recognition')
st.write('## Using trained CNN `Keras` model')
st.write('To view how this model was trained go to the `Files and Versions` tab and download the `Week1.ipynb` notebook')

# Streamlit re-executes this script from the top on every widget interaction,
# so the loaded model is cached instead of re-reading 'mnist.h5' each time.
# `st.cache_resource` is the current caching API; older Streamlit releases
# only provide the legacy `st.cache`, which is looked up only as a fallback.
_cache_model = getattr(st, 'cache_resource', None)
if _cache_model is None:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_model():
    """Load the pre-trained Keras model from 'mnist.h5' under the CPU device.

    `tf.device` only has an effect while used as a context manager, so the
    load happens inside a `with` block; a bare `tf.device(...)` call would
    create the context object and discard it without pinning anything.
    """
    with tf.device('/cpu:0'):
        return tf.keras.models.load_model('mnist.h5')


# Import Pre-trained Model
model = _load_model()
plt.rcParams.update({'font.size': 18})

# Drawing settings are exposed in the sidebar.
stroke_width = st.sidebar.slider(
    "Stroke width: ", min_value=1, max_value=25, value=9)
realtime_update = st.sidebar.checkbox("Update in realtime", value=True)

# Square free-draw surface, 28 * 9 pixels per side: white strokes on a black
# background.
canvas_side = 28 * 9
canvas_result = st_canvas(
    key="canvas",
    drawing_mode='freedraw',
    width=canvas_side,
    height=canvas_side,
    background_color='#000000',
    stroke_color='#FFFFFF',
    stroke_width=stroke_width,
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    update_streamlit=realtime_update,
)

if canvas_result.image_data is not None:

    # Get image data from canvas: interpret the pixels as RGBA, convert to
    # grayscale and shrink to 28x28.
    im = ImageOps.grayscale(Image.fromarray(canvas_result.image_data.astype(
        'uint8'), mode="RGBA")).resize((28, 28))

    # Convert image to array, scale the 0-255 pixel values to 0-1 and reshape
    # to a float32 batch holding one 28x28 single-channel image.
    data = img_to_array(im)
    data = data / 255
    data = data.reshape(1, 28, 28, 1)
    data = data.astype('float32')

    # Predict digit. `tf.device` only applies while its context is entered,
    # so the prediction runs inside the `with` block to stay on the CPU.
    st.write('### Predicted Digit')
    with tf.device('/cpu:0'):
        prediction = model.predict(data)

    # Plot prediction: one bar per digit 0-9 from the first (only) batch row.
    # The figure is closed after rendering because Streamlit reruns this
    # script on every interaction and pyplot keeps every open figure alive.
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.bar(range(10), prediction[0])
        ax.set_xticks(range(10))
        ax.set_xlabel('Digit')
        ax.set_ylabel('Probability')
        ax.set_title('Drawing Prediction')
        ax.set_ylim(0, 1)
        st.write(fig)
    finally:
        plt.close(fig)

    # Show resized image
    with st.expander('Show Resized Image'):
        st.write(
            "The image needs to be resized, because it can only input 28x28 images")
        st.image(im, caption='Resized Image', width=28*9)