hwberry2 commited on
Commit
d27ee2d
·
1 Parent(s): 2c7bae6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -14,7 +14,7 @@ with gr.Blocks() as demo:
14
  * Train/Eval will setup, train, and evaluate the base model
15
  """)
16
 
17
- def modelTraining():
18
  mnist = tf.keras.datasets.mnist
19
 
20
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
@@ -42,14 +42,13 @@ with gr.Blocks() as demo:
42
  model.fit(x_train, y_train, epochs=5)
43
  test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)]
44
 
45
- result = "Test accuracy: ", test_acc
46
 
47
- return result
48
-
49
- def predict_image(img):
50
  # Define any necessary preprocessing steps for the image input here
 
 
51
  # Make a prediction using the model
52
- prediction = model.predict(img)
53
 
54
  # Postprocess the prediction and return it
55
  return prediction
@@ -57,14 +56,12 @@ with gr.Blocks() as demo:
57
 
58
  # Creates the Gradio interface objects
59
  with gr.Row():
60
- with gr.Column(scale=1):
61
- submit_btn = gr.Button(value="Train/Eval")
62
  with gr.Column(scale=2):
 
 
63
  model_performance = gr.Text(label="Model Performance", interactive=False)
64
  model_prediction = gr.Text(label="Model Prediction", interactive=False)
65
- image_data = gr.Image(label="Upload Image", type="numpy")
66
- submit_btn.click(modelTraining, [], model_performance)
67
- image_data.change(predict_image, image_data, model_prediction)
68
 
69
 
70
  # creates a local web server
 
14
  * Train/Eval will setup, train, and evaluate the base model
15
  """)
16
 
17
+ def modelTraining(img):
18
  mnist = tf.keras.datasets.mnist
19
 
20
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
 
42
  model.fit(x_train, y_train, epochs=5)
43
  test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)]
44
 
45
+ print "Test accuracy: ", test_acc
46
 
 
 
 
47
  # Define any necessary preprocessing steps for the image input here
48
+ probability_model = tf.keras.Sequential([model,
49
+ tf.keras.layers.Softmax()])
50
  # Make a prediction using the model
51
+ prediction = probability_model.predict(img)
52
 
53
  # Postprocess the prediction and return it
54
  return prediction
 
56
 
57
  # Creates the Gradio interface objects
58
  with gr.Row():
 
 
59
  with gr.Column(scale=2):
60
+ image_data = gr.Image(label="Upload Image", type="numpy")
61
+ with gr.Column(scale=1):
62
  model_performance = gr.Text(label="Model Performance", interactive=False)
63
  model_prediction = gr.Text(label="Model Prediction", interactive=False)
64
+ image_data.change(modelTraining, image_data, model_prediction)
 
 
65
 
66
 
67
  # creates a local web server