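# Hugging Face file-view header (page residue, kept for provenance):
# author: hwberry2 · commit 2c7bae6 "Update app.py" · size 2.3 kB
# (page controls: raw · history · blame)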
import gradio as gr
import tensorflow as tf
# Create a Gradio App using Blocks: title plus a collapsible instructions panel.
with gr.Blocks() as demo:
    gr.Markdown("# AI/ML Playground")
    with gr.Accordion("Click for Instructions:"):
        # Single-line strings keep the Markdown unindented, so nothing renders as a code block.
        gr.Markdown(
            "* Train/Eval will setup, train, and evaluate the base model\n"
            "* After training, upload an image of a handwritten digit to see "
            "the model's prediction"
        )
def modelTraining():
    """Train a small dense MNIST classifier and report its test accuracy.

    Loads MNIST through ``tf.keras.datasets.mnist``, scales pixel values into
    [0, 1], trains a Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(10)
    network for 5 epochs and evaluates it on the test split.

    The trained network is published as the module-level name ``model``,
    which is the name ``predict_image`` reads. It is assigned only after
    training and evaluation have finished.

    Returns:
        A string such as ``"Test accuracy: 0.9771"`` for the Gradio text box.
    """
    global model  # shared with predict_image, which reads `model`

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values into [0, 1] by dividing by 255.
    x_train, x_test = x_train / 255.0, x_test / 255.0

    new_model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10),  # no softmax: outputs logits, hence from_logits=True
    ])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    new_model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
    new_model.fit(x_train, y_train, epochs=5)

    _, test_acc = new_model.evaluate(x_test, y_test, verbose=2)

    # Publish only a fully trained model, so a prediction request can never see
    # a half-trained one.
    model = new_model
    return f"Test accuracy: {test_acc:.4f}"
def predict_image(img):
    """Classify an uploaded image as a handwritten digit (0-9).

    Args:
        img: The value from the ``gr.Image(type="numpy")`` component, or
            ``None`` when the image is cleared. Assumed to be an (H, W),
            (H, W, 1), (H, W, 3) or (H, W, 4) array with 0-255 pixel values
            (TODO: confirm against the Gradio version in use).

    Returns:
        A human-readable prediction string. Returns an empty string when
        there is no image, or a hint to train first when no trained model
        exists yet.
    """
    if img is None:
        return ""
    try:
        # `model` is expected to be a module-level global holding the trained Keras model.
        trained_model = model
    except NameError:
        return "Please click Train/Eval to train the model first."

    # Bring an arbitrary upload into the training input format: grayscale 28x28 in [0, 1].
    image = tf.cast(tf.convert_to_tensor(img), tf.float32)
    if image.shape.rank == 2:
        image = image[..., tf.newaxis]  # (H, W) -> (H, W, 1)
    if image.shape[-1] == 4:
        image = image[..., :3]  # drop the alpha channel
    if image.shape[-1] == 3:
        image = tf.image.rgb_to_grayscale(image)
    image = tf.image.resize(image, (28, 28)) / 255.0

    # Heuristic: MNIST training digits are presumably light strokes on a dark
    # background. Invert mostly-bright uploads (e.g. dark ink on white paper)
    # to match that.
    if float(tf.reduce_mean(image)) > 0.5:
        image = 1.0 - image

    batch = tf.reshape(image, (1, 28, 28))  # the model expects a batch dimension
    logits = trained_model(batch, training=False)
    probs = tf.nn.softmax(logits[0]).numpy()
    digit = int(probs.argmax())
    return f"Predicted digit: {digit} ({probs[digit]:.1%} confidence)"
# Creates the Gradio interface objects and wires up the event handlers.
# `with demo:` (re-)enters the app's Blocks context, so every component created
# below is attached to `demo`.
with demo:
    with gr.Row():
        with gr.Column(scale=1):
            submit_btn = gr.Button(value="Train/Eval")
        with gr.Column(scale=2):
            model_performance = gr.Text(label="Model Performance", interactive=False)
            model_prediction = gr.Text(label="Model Prediction", interactive=False)
            image_data = gr.Image(label="Upload Image", type="numpy")
    # The button takes no inputs and writes the accuracy report to the text box.
    submit_btn.click(modelTraining, [], model_performance)
    # Re-run the prediction whenever the uploaded image changes (including when it is cleared).
    image_data.change(predict_image, image_data, model_prediction)
# Start a local web server for the app when run as a script. share=False keeps
# it local; share=True would additionally create a temporary public share link
# (not a Hugging Face deployment).
if __name__ == "__main__":
    demo.launch(share=False)