"""Gradio demo: zero-shot image classification with a SigLIP checkpoint.

The user supplies an image (file upload or URL) plus a comma-separated list
of candidate labels. The app returns two images: a horizontal bar chart and
a table, both showing the per-label scores normalised to sum to 100%.
"""
from functools import lru_cache
from io import BytesIO
import random
from typing import List, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from PIL import Image
from plottable import Table
from transformers import pipeline

MODEL_NAME = "google/siglip-so400m-patch14-384"
REQUEST_TIMEOUT_SECONDS = 30
NO_INPUT_MESSAGE = "Please upload an image or enter an image URL."


@lru_cache(maxsize=1)
def _get_classifier():
    """Build the zero-shot pipeline once and reuse it for every request."""
    return pipeline(task="zero-shot-image-classification", model=MODEL_NAME)


def _load_image(upload: Optional[bytes], url: Optional[str]) -> Image.Image:
    """Return the image from the uploaded bytes, else from the URL."""
    try:
        if upload is not None:
            return Image.open(BytesIO(upload))

        # An empty Gradio Textbox yields "" rather than None, so test
        # truthiness instead of identity with None.
        url = (url or "").strip()
        if not url:
            raise gr.Error(NO_INPUT_MESSAGE)

        # NOTE(security): the URL is user-supplied; restrict reachable hosts
        # if this app is deployed where internal services are reachable.
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except requests.RequestException as exc:
        raise gr.Error(f"Could not download the image: {exc}") from exc
    except OSError as exc:
        # PIL raises UnidentifiedImageError (an OSError) for non-image data.
        raise gr.Error(f"Could not read the image: {exc}") from exc


def _parse_labels(labels: Optional[str]) -> List[str]:
    """Split on commas, strip whitespace, drop blanks and duplicates."""
    cleaned = (part.strip() for part in (labels or "").split(","))
    # dict.fromkeys de-duplicates while preserving the user's order.
    return list(dict.fromkeys(part for part in cleaned if part))


def _figure_to_image(fig) -> Image.Image:
    """Render *fig* to a PNG-backed PIL image and close the figure."""
    buffer = BytesIO()
    try:
        fig.savefig(buffer, format="png")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now so the image no longer depends on the buffer
    return image


def _render_bar_chart(labels: List[str], scores: List[float]) -> Image.Image:
    """Draw the horizontal bar chart of normalised scores."""
    fig, ax = plt.subplots(figsize=(10, 6))
    count = len(labels)
    colors = [plt.cm.viridis(i / count) for i in range(count)]
    ax.barh(labels, scores, color=colors)
    ax.set_xlabel('Score (%)')
    ax.set_ylabel('Labels')
    ax.set_title('Classification Results')
    ax.invert_yaxis()  # first (highest-scoring) label at the top
    fig.tight_layout()
    return _figure_to_image(fig)


def _render_table(labels: List[str], scores: List[float]) -> Image.Image:
    """Draw the label/score table as an image."""
    df = pd.DataFrame({"Labels": labels, "Scores (%)": scores})
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.axis('tight')
    ax.axis('off')
    ax.table(cellText=df.values, colLabels=df.columns, loc='center')
    return _figure_to_image(fig)


def classify_image(
    upload: Optional[bytes], url: Optional[str], labels: str
) -> Tuple[Image.Image, Image.Image]:
    """Classify an image against user-provided labels.

    Args:
        upload: Raw bytes of the uploaded file, or None when nothing was
            uploaded. Takes precedence over ``url``.
        url: Image URL, used only when ``upload`` is None.
        labels: Comma-separated candidate labels.

    Returns:
        A ``(bar_chart, table)`` pair of PIL images showing each label's
        score as a percentage of the summed scores.

    Raises:
        gr.Error: If no image source or no usable label is given, or if the
            image cannot be downloaded or decoded.
    """
    image = _load_image(upload, url)

    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise gr.Error("Please enter at least one label.")

    outputs = _get_classifier()(image, candidate_labels=candidate_labels)
    result_labels = [output["label"] for output in outputs]
    scores = [output["score"] for output in outputs]

    # Rescale so the displayed percentages sum to 100; guard against a
    # zero total so the division cannot fail.
    total_score = sum(scores)
    if total_score > 0:
        normalized_scores = [
            round(score * 100 / total_score, 2) for score in scores
        ]
    else:
        normalized_scores = [0.0] * len(scores)

    return (
        _render_bar_chart(result_labels, normalized_scores),
        _render_table(result_labels, normalized_scores),
    )


# Create the Gradio interface
interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.File(type="binary", label="Upload Image"),
        gr.Textbox(label="Or, enter Image URL"),
        gr.Textbox(label="Enter labels separated by commas (e.g., animal, human, building)"),
    ],
    outputs=[
        gr.Image(label="Classification Results (Bar Chart)"),
        gr.Image(label="Classification Results (Table)"),
    ],
    title="Image Classifier",
    description="Upload an image or enter an image URL, then specify labels to classify the image.",
)

# Launch the interface
interface.launch()