import gradio as gr
import tensorflow as tf
import numpy as np
import os
import PIL
import PIL.Image

# Create a Gradio app using Blocks. Every component and event listener created
# inside this `with` context becomes part of `demo`.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # AI/ML Playground
    """
    )
    # Collapsible usage notes shown above the upload widget.
    with gr.Accordion("Click for Instructions:"):
        gr.Markdown(
    """
    * uploading an image will engage the model in image classification
    * trained on the following image types: 'T-shirt/top', 'Trouser', 'Pullover', 'Dress',
      'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
    """)

    # Train, evaluate and test an ML image-classification model for
    # clothing images.

    # Human-readable labels, in dataset label order: integer label i in the
    # Fashion-MNIST data corresponds to class_names[i].
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    # Clothing dataset. This must be Fashion-MNIST, not `mnist` (handwritten
    # digits): the labels above are the Fashion-MNIST class names, so training
    # on digits would make every predicted label meaningless.
    fashion_mnist = tf.keras.datasets.fashion_mnist

    # load_data() already returns separate train and test splits.
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # Scale the 0-255 pixel values into [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Small fully-connected classifier: flatten the 28x28 image, one hidden
    # layer with dropout for regularisation, and one raw logit per class.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(len(class_names)),
    ])

    # The final Dense layer has no activation (it emits logits), so the loss
    # must be told from_logits=True. Labels are integer class ids -> "Sparse".
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    model.compile(optimizer='adam',
                  loss=loss_fn,
                  metrics=['accuracy'])

    # Train for 5 epochs, holding out 30% of the training data for validation,
    # then report loss/accuracy on the separate test split.
    model.fit(x_train, y_train, epochs=5, validation_split=0.3)
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    post_train_results = f"Test accuracy: {test_acc} Test Loss: {test_loss}"
    print(post_train_results)

    # Model used for serving: the trained network followed by a softmax, so it
    # outputs class probabilities instead of logits.
    probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])

        
    def classifyImage(img):
        """Classify an uploaded image and return the predicted clothing label.

        Args:
            img: Value delivered by the ``gr.Image`` component -- presumably an
                image array of arbitrary height/width (RGB with the default
                settings; TODO confirm for the installed Gradio version), or
                ``None`` when the image has been cleared.

        Returns:
            The entry of ``class_names`` with the highest predicted
            probability, or ``""`` when there is no image, which clears the
            prediction text box.
        """
        # The change event also fires when the image is cleared; without this
        # guard np.array(None) / 255.0 raises a TypeError.
        if img is None:
            return ""

        # The model's Flatten layer expects 28x28 single-channel input, so
        # convert the upload to grayscale ("L") and resize it. Passing the raw
        # (H, W, channels) array makes Keras reject the input shape.
        pil_img = PIL.Image.fromarray(np.asarray(img).astype(np.uint8))
        pil_img = pil_img.convert("L").resize((28, 28))
        # NOTE(review): the training images show light garments on a dark
        # background. Typical photos are the opposite, so inverting the pixels
        # (1 - pixels) may improve predictions. This is not done here; verify
        # with real uploads.

        # Normalize the pixel values to [0, 1], matching the training data.
        pixels = np.asarray(pil_img, dtype=np.float32) / 255.0

        # Add a leading batch dimension: (28, 28) -> (1, 28, 28).
        input_array = np.expand_dims(pixels, axis=0)

        # Make a prediction using the model (verbose=0 suppresses the
        # per-call progress bar).
        prediction = probability_model.predict(input_array, verbose=0)

        # The index of the most probable class selects its label.
        return class_names[int(np.argmax(prediction))]
        
    def do_nothing():
        """Intentional no-op used as the image component's clear handler."""
        return None
        
    # Lay out the interactive part of the UI side by side: the upload widget
    # (relative scale 2) and a read-only box for the model's answer (scale 1).
    with gr.Row():
        with gr.Column(scale=2):
            image_data = gr.Image(label="Upload Image")
        with gr.Column(scale=1):
            model_prediction = gr.Text(label="Model Prediction", interactive=False)

        # Re-classify whenever the uploaded image changes; the label string
        # returned by classifyImage is shown in the prediction box.
        image_data.change(fn=classifyImage, inputs=image_data, outputs=model_prediction)
        # Clearing the image runs a handler that deliberately does nothing.
        image_data.clear(fn=do_nothing)
    
    
# Start the Gradio web server locally. With share=False no public link is
# created; share=True would additionally create a temporary, publicly
# accessible share link to this app (see Gradio's Blocks.launch docs).
demo.launch(share=False)