# Spaces: Runtime error
import random
from functools import lru_cache
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from PIL import Image
from plottable import Table
from transformers import pipeline
# Hugging Face model id used for zero-shot image classification.
_MODEL_ID = "google/siglip-so400m-patch14-384"
# Seconds to wait for a remote image before giving up on the request.
_URL_TIMEOUT_SECONDS = 30


@lru_cache(maxsize=1)
def _get_classifier():
    """Build the zero-shot pipeline once and reuse it for every request."""
    return pipeline(task="zero-shot-image-classification", model=_MODEL_ID)


def _load_image(upload, url):
    """Return a PIL image from uploaded bytes or, failing that, from a URL."""
    if upload is not None:
        source = BytesIO(upload)
    else:
        # An untouched Textbox yields "" rather than None, so test truthiness.
        url = (url or "").strip()
        if not url:
            raise gr.Error("Please upload an image or enter an image URL.")
        try:
            response = requests.get(url, timeout=_URL_TIMEOUT_SECONDS)
            response.raise_for_status()
        except requests.RequestException as exc:
            raise gr.Error(f"Could not download the image: {exc}") from exc
        # Use the decoded body; `.raw` would skip gzip/deflate decoding.
        source = BytesIO(response.content)
    try:
        return Image.open(source)
    except OSError as exc:
        # Pillow's UnidentifiedImageError is an OSError subclass.
        raise gr.Error("The provided file is not a readable image.") from exc


def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG, close it, and return a PIL image."""
    buf = BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def classify_image(upload, url, labels):
    """Classify an image against user-supplied labels and visualise the result.

    The image is taken from ``upload`` when present, otherwise downloaded
    from ``url``. Scores returned by the pipeline are rescaled so that they
    sum to 100 % across the supplied labels.

    Args:
        upload: Raw bytes of an uploaded image file, or ``None``.
        url: Address of an image to download; ignored when ``upload`` is set.
        labels: Comma-separated candidate labels.

    Returns:
        A ``(bar_chart, table)`` tuple of PIL images.

    Raises:
        gr.Error: If no image or no label is supplied, the download fails,
            or the data cannot be opened as an image.
    """
    image = _load_image(upload, url)

    # Drop blanks ("a,,b", trailing commas) and repeated labels, keeping order.
    stripped = (part.strip() for part in (labels or "").split(","))
    candidate_labels = list(dict.fromkeys(part for part in stripped if part))
    if not candidate_labels:
        raise gr.Error("Please enter at least one label.")

    outputs = _get_classifier()(image, candidate_labels=candidate_labels)
    result_labels = [output["label"] for output in outputs]
    scores = [output["score"] for output in outputs]

    # Normalize scores to sum up to 100%, guarding against an all-zero result.
    total_score = sum(scores)
    if total_score > 0:
        normalized_scores = [
            round(score * 100 / total_score, 2) for score in scores
        ]
    else:
        normalized_scores = [0.0] * len(scores)

    # Horizontal bar chart with a distinct viridis colour per label.
    count = len(result_labels)
    colors = [plt.cm.viridis(i / count) for i in range(count)]
    chart_fig, chart_ax = plt.subplots(figsize=(10, 6))
    chart_ax.barh(result_labels, normalized_scores, color=colors)
    chart_ax.set_xlabel('Score (%)')
    chart_ax.set_ylabel('Labels')
    chart_ax.set_title('Classification Results')
    chart_ax.invert_yaxis()  # Display labels from top to bottom
    chart_fig.tight_layout()
    result_image = _figure_to_image(chart_fig)

    # Same data rendered as a plain matplotlib table.
    df = pd.DataFrame(
        {"Labels": result_labels, "Scores (%)": normalized_scores}
    )
    table_fig, table_ax = plt.subplots(figsize=(6, 5))
    table_ax.axis('tight')
    table_ax.axis('off')
    table_ax.table(cellText=df.values, colLabels=df.columns, loc='center')
    result_table_image = _figure_to_image(table_fig)

    return result_image, result_table_image
# Input widgets, listed in the positional order classify_image receives them.
_input_components = [
    gr.File(type="binary", label="Upload Image"),
    gr.Textbox(label="Or, enter Image URL"),
    gr.Textbox(label="Enter labels separated by commas (e.g., animal, human, building)"),
]

# Output widgets, matching the (bar chart, table) pair classify_image returns.
_output_components = [
    gr.Image(label="Classification Results (Bar Chart)"),
    gr.Image(label="Classification Results (Table)"),
]

# Wire the classifier and its widgets into a single Gradio app.
interface = gr.Interface(
    fn=classify_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Image Classifier",
    description="Upload an image or enter an image URL, then specify labels to classify the image.",
)

# Start serving the app.
interface.launch()