"""Gradio demo: cluster an image's pixel colours with k-means and plot the
cluster centres as points in 3D RGB space, each drawn in its own colour."""

import os

import gradio as gr
import numpy as np
import plotly.graph_objs as go
from sklearn.cluster import KMeans

# Directory scanned for example images shown under the interface.
EXAMPLES_DIR = "examples"
# File extensions treated as example images (compared case-insensitively).
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")


def _get_cluster_colors(cluster_centers) -> list[str]:
    """Return a Plotly ``rgb(r,g,b)`` colour string for each cluster centre.

    K-means centres are floats. Formatting them directly can produce
    non-integer channels or scientific notation (e.g. ``3e-05``), which
    Plotly's colour-string validator rejects, so every channel is rounded to
    the nearest integer and clipped to the valid 0-255 range.

    Args:
        cluster_centers: Array-like of RGB centres, one row of 3 values each.

    Returns:
        One ``"rgb(r,g,b)"`` string per centre, in input order.
    """
    channels = np.clip(np.rint(np.asarray(cluster_centers, dtype=float)), 0, 255)
    return [f"rgb({r},{g},{b})" for r, g, b in channels.astype(int)]


def _to_rgb_pixels(image) -> np.ndarray:
    """Flatten an image array into an ``(n_pixels, 3)`` array of RGB values.

    A 2-D (grayscale) array, or an array with fewer than three channels, has
    its first channel replicated into R, G and B; channels beyond the third
    (e.g. alpha) are dropped.

    Raises:
        gr.Error: If the array is not 2-D or 3-D.
    """
    arr = np.asarray(image)
    if arr.ndim == 2:
        arr = arr[..., np.newaxis]
    if arr.ndim != 3:
        raise gr.Error("Expected an RGB, RGBA or grayscale image.")
    if arr.shape[-1] < 3:
        # Grayscale (optionally with alpha): reuse the first channel for RGB.
        arr = np.repeat(arr[..., :1], 3, axis=-1)
    return arr[..., :3].reshape(-1, 3)


def plot_image(Image, N_Clusters: int) -> gr.Plot:
    """Cluster the image's pixel colours with k-means and plot the centres.

    Args:
        Image: Image array delivered by ``gr.Image()``; ``None`` when the user
            submits without uploading anything.
        N_Clusters: Requested number of clusters. The slider may deliver a
            float, so it is cast to ``int``. It is also capped at the pixel
            count, because k-means cannot form more clusters than samples.

    Returns:
        A ``gr.Plot`` wrapping a 3D scatter of the cluster centres, each
        marker coloured with that centre's own RGB value.

    Raises:
        gr.Error: If no (non-empty) image was provided or its shape is invalid.
    """
    if Image is None or np.size(Image) == 0:
        raise gr.Error("Please upload an image.")

    pixels = _to_rgb_pixels(Image)
    n_clusters = max(1, min(int(N_Clusters), len(pixels)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=1).fit(pixels)
    centers = kmeans.cluster_centers_
    cluster_colors = _get_cluster_colors(centers)

    fig = go.Figure(data=[go.Scatter3d(
        x=centers[:, 0],
        y=centers[:, 1],
        z=centers[:, 2],
        mode='markers',
        marker=dict(
            color=cluster_colors,  # each centre drawn in its own colour
            opacity=0.9,
        ),
    )])

    # Label the axes with the colour channel each one represents.
    fig.update_layout(
        scene=dict(
            xaxis_title="Red Channel",
            yaxis_title="Green Channel",
            zaxis_title="Blue Channel",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )

    return gr.Plot(fig)


def _list_examples(directory: str = EXAMPLES_DIR, n_clusters: int = 100) -> list[list]:
    """Return ``[image_path, n_clusters]`` example rows, sorted by file name.

    Only files with an image extension are included. A missing directory
    yields an empty list instead of crashing the app at import time.
    """
    try:
        names = sorted(os.listdir(directory))
    except FileNotFoundError:
        return []
    return [
        [os.path.join(directory, name), n_clusters]
        for name in names
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


interface = gr.Interface(
    fn=plot_image,
    title="3D RGB Cluster Visualization",
    inputs=[gr.Image(), gr.Slider(minimum=20, maximum=500, step=1)],
    outputs=gr.Plot(),
    # No examples section at all when the directory is absent or empty.
    examples=_list_examples() or None,
    cache_examples="lazy",
)

if __name__ == "__main__":
    interface.launch()