# ClusteRGB / app.py
# Author: Jensen-holm
# Commit c7bede3: "setting cache_examples='lazy' because it took forever to build"
# (raw | history | blame) - 1.46 kB
import gradio as gr
import plotly.graph_objs as go
from sklearn.cluster import KMeans
import os
def _get_cluster_colors(cluster_centers) -> list[str]:
cluster_colors = []
for r, g, b in cluster_centers:
cluster_colors.append(f"rgb({r},{g},{b})")
return cluster_colors
def plot_image(Image, N_Clusters: int) -> gr.Plot:
    """Cluster an image's pixel colours and plot the cluster centres in RGB space.

    Args:
        Image: Image array from the ``gr.Image`` input. Assumed to be
            (H, W, C) with RGB or RGBA channels; TODO confirm the component's
            image_mode. ``None`` when the user submits without an upload.
        N_Clusters: Requested number of K-means clusters (from the slider).

    Returns:
        A ``gr.Plot`` wrapping a 3-D Plotly scatter with one marker per cluster
        centre, placed at and coloured by that centre's (R, G, B) value.

    Raises:
        gr.Error: If no image was provided.
    """
    if Image is None:
        # Show a readable UI error instead of an AttributeError traceback.
        raise gr.Error("Please upload an image.")

    # Drop any alpha channel so each row below is exactly one (R, G, B) pixel.
    # Reshaping RGBA data to (-1, 3) would silently mix channels together.
    if Image.ndim == 3:
        Image = Image[..., :3]
    img_flat = Image.reshape(-1, 3)

    # KMeans needs an integer cluster count and at least as many samples as
    # clusters. Slider values may arrive as floats, and tiny images can have
    # fewer pixels than the requested cluster count.
    n_clusters = min(int(N_Clusters), len(img_flat))
    kmeans = KMeans(n_clusters, random_state=1).fit(img_flat)
    centers = kmeans.cluster_centers_
    cluster_colors = _get_cluster_colors(centers)

    fig = go.Figure(data=[go.Scatter3d(
        x=centers[:, 0],
        y=centers[:, 1],
        z=centers[:, 2],
        mode='markers',
        marker=dict(
            color=cluster_colors,  # each marker shows its own centre colour
            opacity=0.9,
        ),
    )])

    # Label the axes with their colour channels and remove the default margins.
    fig.update_layout(
        scene=dict(
            xaxis_title="Red Channel",
            yaxis_title="Green Channel",
            zaxis_title="Blue Channel",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return gr.Plot(fig)
# Directory of bundled example images, resolved relative to the working
# directory, as in the original app.
_EXAMPLES_DIR = "examples"

# Each example row is [image path, default cluster count]. Sort the listing
# because os.listdir order is arbitrary. Keep only regular, non-hidden files so
# entries such as .DS_Store or sub-directories are not passed as images.
# If the directory is missing, start without examples instead of crashing.
_examples = (
    [
        [os.path.join(_EXAMPLES_DIR, image), 100]
        for image in sorted(os.listdir(_EXAMPLES_DIR))
        if not image.startswith(".")
        and os.path.isfile(os.path.join(_EXAMPLES_DIR, image))
    ]
    if os.path.isdir(_EXAMPLES_DIR)
    else None
)

interface = gr.Interface(
    fn=plot_image,
    title="3D RGB Cluster Visualization",
    inputs=[gr.Image(), gr.Slider(minimum=20, maximum=500)],
    outputs=gr.Plot(),
    examples=_examples,
    # Cache example outputs on first use rather than at build time.
    cache_examples="lazy",
)

if __name__ == "__main__":
    interface.launch()