File size: 1,464 Bytes
69e0433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45901b5
69e0433
 
 
c7bede3
69e0433
 
 
 
73bed59
69e0433
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
import plotly.graph_objs as go
from sklearn.cluster import KMeans
import os


def _get_cluster_colors(cluster_centers) -> list[str]:
    cluster_colors = []
    for r, g, b in cluster_centers:
        cluster_colors.append(f"rgb({r},{g},{b})")
    return cluster_colors


def plot_image(Image, N_Clusters: int) -> gr.Plot:
    """Cluster an image's pixel colours with KMeans and plot the centres in RGB space.

    Args:
        Image: Image array as delivered by ``gr.Image`` -- assumed to be a
            numpy array of shape ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``
            (TODO confirm for other image modes). ``None`` means no upload.
        N_Clusters: Requested number of KMeans clusters. It is capped at the
            number of pixels, because KMeans cannot fit more clusters than
            samples.

    Returns:
        A ``gr.Plot`` wrapping a 3-D scatter of the cluster centres. Each
        marker is coloured with its own centre colour.

    Raises:
        gr.Error: If no image was provided or it has fewer than 3 channels.
    """
    if Image is None:
        # gr.Image yields None when the user submits without uploading.
        raise gr.Error("Please upload an image first.")
    if Image.ndim == 2:
        # Greyscale: replicate the single channel so each pixel is an RGB triple.
        Image = Image[:, :, None].repeat(3, axis=2)
    if Image.ndim != 3 or Image.shape[2] < 3:
        raise gr.Error("Expected an RGB image.")

    # Keep only the RGB channels. Otherwise reshape(-1, 3) on an RGBA image
    # would silently mix channels from neighbouring pixels.
    img_flat = Image[..., :3].reshape(-1, 3)

    # KMeans raises if n_clusters > n_samples, so cap it for tiny images.
    # int() guards against the slider delivering a float.
    n_clusters = max(1, min(int(N_Clusters), len(img_flat)))
    kmeans = KMeans(n_clusters, random_state=1).fit(img_flat)
    centers = kmeans.cluster_centers_
    cluster_colors = _get_cluster_colors(centers)

    fig = go.Figure(data=[go.Scatter3d(
        x=centers[:, 0],
        y=centers[:, 1],
        z=centers[:, 2],
        mode='markers',
        marker=dict(
            color=cluster_colors,  # Each centre is drawn in its own colour
            opacity=0.9,
        )
    )])

    # Label the axes with the RGB channel each coordinate represents.
    fig.update_layout(
        scene=dict(
            xaxis_title="Red Channel",
            yaxis_title="Green Channel",
            zaxis_title="Blue Channel"
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return gr.Plot(fig)


# Resolve examples relative to this script, not the current working directory.
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".tif", ".tiff")


def _list_examples(directory: str, n_clusters: int = 100) -> list[list]:
    """Build Gradio example rows ``[image_path, n_clusters]`` from *directory*.

    - Files are sorted, because ``os.listdir`` order is arbitrary and the
      gallery order should be stable.
    - Non-image files such as ``.gitkeep`` are skipped.
    - A missing directory yields no examples instead of an import-time crash.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name), n_clusters]
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


interface = gr.Interface(
    fn=plot_image,
    title="3D RGB Cluster Visualization",
    inputs=[
        gr.Image(),
        # step=1 keeps the cluster count an integer.
        gr.Slider(minimum=20, maximum=500, step=1, label="Number of clusters"),
    ],
    outputs=gr.Plot(),
    examples=_list_examples(_EXAMPLES_DIR) or None,
    cache_examples="lazy",
)


if __name__ == "__main__":
    interface.launch()