# Spaces: Sleeping
import os

import gradio as gr
import numpy as np
import plotly.graph_objs as go
from sklearn.cluster import KMeans
def _get_cluster_colors(cluster_centers) -> list[str]: | |
cluster_colors = [] | |
for r, g, b in cluster_centers: | |
cluster_colors.append(f"rgb({r},{g},{b})") | |
return cluster_colors | |
def plot_image(Image, N_Clusters: int) -> gr.Plot:
    """Cluster an image's pixel colours with k-means and plot the centres in RGB space.

    Args:
        Image: Image array as delivered by ``gr.Image()``. Presumably
            ``(height, width, channels)`` uint8 -- TODO confirm. 2-D grayscale and
            1-, 2- or 4-channel arrays are normalised to three RGB channels.
        N_Clusters: Requested number of k-means clusters. Cast to ``int`` and
            capped at the number of pixels, because ``KMeans`` rejects
            non-integer values and more clusters than samples.

    Returns:
        A ``gr.Plot`` wrapping a Plotly 3-D scatter with one marker per cluster
        centre, each coloured with that centre's RGB value.

    Raises:
        gr.Error: If no image was provided or the image has no pixels.
    """
    if Image is None:
        raise gr.Error("Please provide an image.")

    pixels = np.asarray(Image)
    if pixels.ndim == 2:
        # Single-channel 2-D array: add an explicit channel axis.
        pixels = pixels[..., np.newaxis]
    if pixels.shape[-1] < 3:
        # Grayscale (L) or grayscale+alpha (LA): replicate the first channel into R, G, B.
        pixels = np.repeat(pixels[..., :1], 3, axis=-1)
    # Keep only the first three channels. This drops an alpha channel, so every
    # row is exactly one (R, G, B) colour instead of mixing channels across rows.
    img_flat = pixels[..., :3].reshape(-1, 3)
    if len(img_flat) == 0:
        raise gr.Error("The provided image contains no pixels.")

    n_clusters = max(1, min(int(N_Clusters), len(img_flat)))
    kmeans = KMeans(n_clusters, random_state=1).fit(img_flat)
    centers = kmeans.cluster_centers_
    cluster_colors = _get_cluster_colors(centers)

    fig = go.Figure(data=[go.Scatter3d(
        x=centers[:, 0],
        y=centers[:, 1],
        z=centers[:, 2],
        mode="markers",
        marker=dict(
            color=cluster_colors,  # each centre is drawn in its own colour
            opacity=0.9,
        ),
    )])

    # Label the axes with the colour channel they represent and drop outer margins.
    fig.update_layout(
        scene=dict(
            xaxis_title="Red Channel",
            yaxis_title="Green Channel",
            zaxis_title="Blue Channel",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return gr.Plot(fig)
_EXAMPLES_DIR = "examples"
# File extensions treated as example images; anything else (e.g. .DS_Store) is skipped.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")


def _list_example_images(directory: str = _EXAMPLES_DIR) -> list[str]:
    """Return sorted paths of image files in *directory*, or [] if it does not exist."""
    if not os.path.isdir(directory):
        return []
    # os.listdir order is arbitrary, so sort for a stable example gallery.
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


# Gradio UI: an image plus a cluster-count slider in, a 3-D Plotly figure out.
# Examples use 100 clusters; "lazy" caches each example on first use.
interface = gr.Interface(
    fn=plot_image,
    title="3D RGB Cluster Visualization",
    inputs=[gr.Image(), gr.Slider(minimum=20, maximum=500, step=1)],
    outputs=gr.Plot(),
    # None (the Interface default) when no example images are available.
    examples=[[path, 100] for path in _list_example_images()] or None,
    cache_examples="lazy",
)
def main() -> None:
    """Launch the Gradio interface defined in this module."""
    interface.launch()


if __name__ == "__main__":
    main()