File size: 1,939 Bytes
5db5524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from keras.preprocessing.image import img_to_array

from streamlit_drawable_canvas import st_canvas
import streamlit as st

# st.set_page_config(layout="wide")  # if enabled, must be the first Streamlit call

st.write('# MNIST Digit Recognition')
st.write('## Using a CNN `Keras` model')

MODEL_PATH = 'mnist.h5'


def _load_model(path=MODEL_PATH):
    """Load the pre-trained Keras digit classifier stored at *path*."""
    return tf.keras.models.load_model(path)


# Streamlit re-executes this whole script on every widget interaction (and on
# every canvas stroke when "Update in realtime" is on). Cache the loaded model
# so the .h5 file is not deserialized on each rerun. st.cache_resource exists
# from Streamlit 1.18; experimental_singleton is its older name. If neither is
# available, fall back to loading on every run, as before.
_cache_resource = (getattr(st, 'cache_resource', None)
                   or getattr(st, 'experimental_singleton', None))
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

# Import Pre-trained Model
model = _load_model()
plt.rcParams.update({'font.size': 18})

# Create a sidebar to hold the settings
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# 224x224 drawing surface: white strokes on a black background. The image is
# later downsampled by a factor of 8 to the 28x28 model input.
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # shape fill; presumably unused in 'freedraw' mode
    stroke_width=stroke_width,
    stroke_color='#FFFFFF',
    background_color='#000000',
    update_streamlit=realtime_update,
    height=224,
    width=224,
    drawing_mode='freedraw',
    key="canvas",
)

if canvas_result.image_data is not None:
    st.write('### Resized Image')
    st.write("The image needs to be resized, because it can only input 28x28 images")
    # The canvas hands back an RGBA pixel array. Collapse it to a single
    # grayscale channel and downsample it to the 28x28 size the model takes.
    rgba = Image.fromarray(canvas_result.image_data.astype('uint8'), mode="RGBA")
    im = ImageOps.grayscale(rgba).resize((28, 28))
    st.image(im, width=224)  # upscale only for display

    # Scale pixels to [0, 1] and add batch and channel axes -> (1, 28, 28, 1).
    data = (img_to_array(im) / 255).reshape(1, 28, 28, 1).astype('float32')

    st.write('### Predicted Digit')
    # verbose=0 suppresses the Keras progress bar, which would otherwise be
    # printed to the server log on every rerun.
    prediction = model.predict(data, verbose=0)
    probabilities = prediction[0]
    st.write(f'**{int(probabilities.argmax())}**')

    # Use the object-oriented API and close the figure afterwards. pyplot keeps
    # every figure it creates alive until closed, so without plt.close each
    # rerun would leak one figure.
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), probabilities)
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    st.pyplot(fig)
    plt.close(fig)