Spaces:
Sleeping
Sleeping
import numpy as np | |
import tensorflow as tf | |
import tensorflow.keras as keras | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from huggingface_hub import from_pretrained_keras | |
# Download the already-pushed MIL model from the Hugging Face Hub.
# Further models can be appended to this list to form an ensemble;
# `predict` averages its outputs over every entry.
trained_models = [from_pretrained_keras("buio/attention_mil_classification")]

POSITIVE_CLASS = 1  # MNIST label whose presence makes a bag positive.
BAG_COUNT = 1000  # Training bag count (not used by the code in this demo).
VAL_BAG_COUNT = 300  # Number of validation bags built at startup.
BAG_SIZE = 3  # Number of instances (images) per bag.
PLOT_SIZE = 1  # Number of bags drawn per plot request.
# Derive the averaging divisor from the loaded models instead of hard-coding
# it, so adding an ensemble member cannot silently skew the averages.
ENSEMBLE_AVG_COUNT = len(trained_models)
def create_bags(input_data, input_labels, positive_class, bag_count, instance_count, rng=None):
    """Group random samples into fixed-size bags for multiple-instance learning.

    Each bag holds ``instance_count`` distinct samples, drawn without
    replacement from ``input_data``. A bag is labelled 1 if at least one of its
    samples has label ``positive_class``, otherwise 0. Sample values are
    divided by 255.

    Args:
        input_data: Array of samples indexed along axis 0.
        input_labels: Per-sample labels, aligned with ``input_data`` on axis 0.
        positive_class: Label value that makes a bag positive.
        bag_count: Number of bags to create.
        instance_count: Number of samples per bag. It must not exceed the
            number of samples, because sampling is without replacement.
        rng: Optional ``numpy.random.Generator`` / ``RandomState`` used for
            sampling. Defaults to the global ``np.random`` state.

    Returns:
        A tuple ``(bags, bag_labels)``:

        * ``bags`` is a list of ``instance_count`` arrays. The k-th array holds
          the k-th instance of every bag and has shape
          ``(bag_count, *sample_shape)``. This list is what ``predict`` feeds
          to ``model.predict``.
        * ``bag_labels`` has shape ``(bag_count, 1)``.
    """
    sampler = np.random if rng is None else rng
    input_data = np.divide(np.asarray(input_data), 255.0)
    input_labels = np.asarray(input_labels)

    bags = []
    bag_labels = []
    positive_count = 0
    for _ in range(bag_count):
        # Pick a fixed-size random subset of distinct samples.
        index = sampler.choice(input_data.shape[0], instance_count, replace=False)
        # The bag is positive if any of its instances carries the positive label.
        bag_label = int(positive_class in input_labels[index])
        positive_count += bag_label
        bags.append(input_data[index])
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {positive_count}")
    print(f"Negative bags: {bag_count - positive_count}")

    if not bags:
        # np.swapaxes cannot handle an empty list (it has no axis 1), so build
        # correctly-shaped empty outputs explicitly.
        empty = np.empty((0,) + input_data.shape[1:], dtype=input_data.dtype)
        return [empty.copy() for _ in range(instance_count)], np.empty((0, 1), dtype=int)

    # Convert (bag_count, instance_count, ...) into a list of
    # instance_count arrays, each of shape (bag_count, ...).
    return list(np.swapaxes(np.stack(bags), 0, 1)), np.array(bag_labels)
# Fetch MNIST. Only the held-out split is turned into bags for this demo.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Build the validation bags that the Gradio callback predicts on and plots.
val_data, val_labels = create_bags(
    input_data=x_val,
    input_labels=y_val,
    positive_class=POSITIVE_CLASS,
    bag_count=VAL_BAG_COUNT,
    instance_count=BAG_SIZE,
)
def predict(data, labels, trained_models):
    """Run every model on ``data`` and average the outputs across the ensemble.

    For each model this collects:

    * the class predictions;
    * the output of the layer named ``"alpha"`` (the MIL attention weights);
    * the evaluation loss and accuracy against ``labels``.

    The averaged loss and accuracy are printed.

    Args:
        data: Model input, as produced by ``create_bags``.
        labels: Bag labels aligned with ``data``.
        trained_models: Non-empty sequence of Keras models, each containing a
            layer named ``"alpha"``.

    Returns:
        ``(mean_predictions, mean_attention_weights)``, averaged over the
        models in ``trained_models``.

    Raises:
        ValueError: If ``trained_models`` is empty.
    """
    if not trained_models:
        raise ValueError("predict() needs at least one trained model")

    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []
    for model in trained_models:
        # Predict output classes on data.
        models_predictions.append(model.predict(data))

        # Sub-model that exposes the MIL attention layer's output.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
        intermediate_predictions = intermediate_model.predict(data)
        # NOTE(review): assumes "alpha" yields one array per instance, so
        # swapping axes 0/1 puts bags first; squeeze then drops size-1 axes.
        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        # Presumably the Hub model is loaded without compile state; compile it
        # so evaluate() has a loss and an accuracy metric.
        model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    # Average over the actual number of models rather than a separately
    # maintained global constant.
    print(
        f"The average loss and accuracy are {np.mean(models_losses):.2f}"
        f" and {100 * np.mean(models_accuracies):.2f} % resp."
    )
    return (
        np.mean(models_predictions, axis=0),
        np.mean(models_attention_weights, axis=0),
    )
def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """Plot a randomly chosen bag of the requested class, with attention weights.

    Args:
        data: Input data that contains the bags of instances, laid out as
            returned by ``create_bags`` (instance axis first, then bags).
        labels: The associated bag labels of the input data.
        bag_class: String name of the desired bag class, either
            ``"positive"`` or ``"negative"``.
        predictions: Per-bag class scores from the model. When given, bags are
            selected by their predicted class (argmax over axis 1). Otherwise
            the ground-truth ``labels`` are used.
        attention_weights: Attention weights for each instance within the
            input data. When omitted, subplot titles are not drawn.

    Returns:
        The matplotlib figure of the last plotted bag. Returns ``None`` when
        ``bag_class`` is unknown or no bag of that class exists.
    """
    class_ids = {"positive": 1, "negative": 0}
    if bag_class not in class_ids:
        print(f"There is no class {bag_class}")
        return None
    target = class_ids[bag_class]

    # Classify each bag by prediction when available, else by ground truth.
    if predictions is not None:
        bag_classes = np.asarray(predictions).argmax(1)
    else:
        bag_classes = np.array(labels).reshape(-1)
    candidates = np.where(bag_classes == target)[0]
    if candidates.size == 0:
        # np.random.choice raises on an empty population; report it instead.
        print(f"No bags found for class {bag_class}")
        return None

    chosen = np.random.choice(candidates, PLOT_SIZE)
    bags = np.array(data)[:, chosen]
    instance_count = len(bags)

    print(f"The bag class label is {bag_class}")
    figure = None
    for i, bag_index in enumerate(chosen):
        figure = plt.figure(figsize=(8, 8))  # one figure per bag
        # Report the index of the bag that is actually drawn.
        print(f"Bag number: {bag_index}")
        for j in range(instance_count):
            figure.add_subplot(1, instance_count, j + 1)
            plt.grid(False)
            plt.axis("off")
            if attention_weights is not None:
                plt.title(np.around(attention_weights[bag_index][j], 2))
            plt.imshow(bags[j][i])
        plt.show()
    return figure
# Evaluate and predict classes and attention scores on validation data.
def predict_and_plot(class_):
    """Gradio callback: predict on the validation bags and plot one of ``class_``.

    Args:
        class_: Requested bag class, ``"positive"`` or ``"negative"``.

    Returns:
        The matplotlib figure returned by ``plot``, or ``None`` when ``plot``
        produces no figure.
    """
    class_predictions, attention_params = predict(val_data, val_labels, trained_models)
    # plot() reads the module-level PLOT_SIZE. Assigning PLOT_SIZE here would
    # only create an unused local, so it is not reassigned.
    return plot(
        val_data,
        val_labels,
        class_,
        predictions=class_predictions,
        attention_weights=attention_params,
    )
# Startup warm-up: runs prediction and plotting once, printing the validation
# loss/accuracy, before the UI is served.
predict_and_plot('positive')

# Default to "positive" so submitting without touching the Radio still yields
# a plot. With no selection the callback would presumably receive None, which
# plot() rejects.
inputs = gr.Radio(choices=['positive', 'negative'], value='positive', label='Bag class')
outputs = gr.Plot(label='predicted bag')
# Model: https://huggingface.co/buio/attention_mil_classification
demo = gr.Interface(fn=predict_and_plot, inputs=inputs, outputs=outputs, allow_flagging='never')
demo.launch(debug=True)