Babyloncoder's picture
Update app.py
44ca1c5 verified
import functools
import random
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from matplotlib.figure import Figure
from PIL import Image
from plottable import Table
from transformers import pipeline
_MODEL_ID = "google/siglip-so400m-patch14-384"
# Seconds to wait for a remote image before giving up, so a slow host cannot hang a worker.
_URL_TIMEOUT_SECONDS = 30


@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Build the zero-shot image-classification pipeline once and reuse it.

    Loading the model weights is expensive, so the pipeline is cached instead of
    being rebuilt for every request.
    """
    return pipeline(task="zero-shot-image-classification", model=_MODEL_ID)


def _load_image(upload, url):
    """Return an RGB PIL image from the uploaded bytes or from *url*.

    The upload takes precedence over the URL. Returns ``None`` when neither an
    upload nor a non-blank URL was supplied. Download or decoding failures are
    re-raised as ``gr.Error`` so the message is shown in the UI.
    """
    try:
        if upload is not None:
            image = Image.open(BytesIO(upload))
        elif url and url.strip():
            response = requests.get(url.strip(), timeout=_URL_TIMEOUT_SECONDS)
            response.raise_for_status()
            # Decode from the (content-decoded) body instead of the raw socket stream.
            image = Image.open(BytesIO(response.content))
        else:
            return None
        # Normalize palette / grayscale / RGBA inputs to 3-channel RGB for the model.
        return image.convert("RGB")
    except (requests.RequestException, OSError) as err:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        raise gr.Error(f"Could not load the image: {err}") from err


def _figure_to_image(fig):
    """Render a matplotlib figure to an in-memory PNG and return it as a PIL image."""
    buffer = BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    return Image.open(buffer)


def _bar_chart_image(labels, scores):
    """Draw a horizontal bar chart of *scores* (percent) per label, one color per bar."""
    # A standalone Figure is not registered with pyplot, so it is garbage-collected
    # normally and does not pile up across requests.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    colors = [plt.cm.viridis(i / len(labels)) for i in range(len(labels))]
    ax.barh(labels, scores, color=colors)
    ax.set_xlabel('Score (%)')
    ax.set_ylabel('Labels')
    ax.set_title('Classification Results')
    ax.invert_yaxis()  # list labels top-to-bottom in the order the pipeline returned them
    fig.tight_layout()
    return _figure_to_image(fig)


def _table_image(df):
    """Render *df* as a plain matplotlib table image."""
    fig = Figure(figsize=(6, 5))
    ax = fig.subplots()
    ax.axis('tight')
    ax.axis('off')
    ax.table(cellText=df.values, colLabels=df.columns, loc='center')
    return _figure_to_image(fig)


def classify_image(upload, url, labels):
    """
    Classify the image either from an uploaded file or a URL with given labels.

    Args:
        upload: Bytes of the uploaded image file, or ``None`` if nothing was
            uploaded. Takes precedence over ``url``.
        url: Image URL entered by the user; may be ``None`` or ``""`` (empty textbox).
        labels: Comma-separated candidate labels, e.g. ``"animal, human, building"``.

    Returns:
        ``(bar_chart, table)``: two PIL images showing each label's score,
        rescaled so that all scores sum to 100%.

    Raises:
        gr.Error: if no image or no label was provided, or the image could not
            be downloaded/decoded. Gradio displays the message to the user.
    """
    image = _load_image(upload, url)
    if image is None:
        # The interface has two Image outputs, so a bare message string cannot be
        # returned; gr.Error surfaces the message in the UI instead.
        raise gr.Error("Please upload an image or enter an image URL.")

    # Split on commas, trim whitespace, drop empty entries and duplicates (order kept).
    candidate_labels = list(
        dict.fromkeys(label.strip() for label in (labels or "").split(",") if label.strip())
    )
    if not candidate_labels:
        raise gr.Error("Please enter at least one label.")

    outputs = _get_classifier()(image, candidate_labels=candidate_labels)
    predicted_labels = [output["label"] for output in outputs]
    scores = [output["score"] for output in outputs]

    # Rescale so the displayed scores sum to 100%; avoid dividing by zero.
    total_score = sum(scores)
    if total_score > 0:
        normalized_scores = [round(score * 100 / total_score, 2) for score in scores]
    else:
        normalized_scores = [0.0] * len(scores)

    df = pd.DataFrame({"Labels": predicted_labels, "Scores (%)": normalized_scores})
    return _bar_chart_image(predicted_labels, normalized_scores), _table_image(df)
# Input widgets, listed in the order classify_image(upload, url, labels) receives them.
_input_components = [
    gr.File(type="binary", label="Upload Image"),
    gr.Textbox(label="Or, enter Image URL"),
    gr.Textbox(label="Enter labels separated by commas (e.g., animal, human, building)"),
]

# One image widget per value classify_image returns: the bar chart, then the table.
_output_components = [
    gr.Image(label="Classification Results (Bar Chart)"),
    gr.Image(label="Classification Results (Table)"),
]

interface = gr.Interface(
    fn=classify_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Image Classifier",
    description="Upload an image or enter an image URL, then specify labels to classify the image.",
)

# Start serving the Gradio app.
interface.launch()