File size: 4,342 Bytes
505a4da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import warnings
from io import BytesIO

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

import streamlit as st
from img2cmap import ImageConverter

# Suppress all Python warnings process-wide for this app.
warnings.filterwarnings("ignore")
st.set_option("deprecation.showfileUploaderEncoding", False)

st.set_page_config(
    page_title="img2cmap web",
    layout="wide",
)

st.title("Convert images to a colormap")
st.markdown(
    """
    This app converts images to colormaps using the Python
    library [img2cmap](https://github.com/arvkevi/img2cmap).
    Try your own image on the left. **Scroll down to generate an optimal colormap.**
    """
)

st.sidebar.markdown("### Image settings")
file_or_url = st.sidebar.radio("Upload an image file or paste an image URL", ("file", "url"))

# Start from "no image" so the default fallback below applies on every path,
# including the defensive `else` branch, which would otherwise leave
# `user_image` unbound.
user_image = None
if file_or_url == "file":
    uploaded_file = st.sidebar.file_uploader("Upload an image file")
    if uploaded_file is not None:
        # Copy the uploaded bytes into a fresh in-memory buffer for ImageConverter.
        user_image = BytesIO(uploaded_file.getvalue())
elif file_or_url == "url":
    # Strip so that a whitespace-only entry is treated the same as an empty box.
    user_image = st.sidebar.text_input(
        "Paste an image URL", "https://static1.bigstockphoto.com/3/2/3/large1500/323952496.jpg"
    ).strip()
else:
    st.warning("Please select an option")

# Fall back to a default example image when nothing usable was provided.
# `not` (rather than `is None`) also catches a cleared URL box, which yields "".
# A BytesIO upload is always truthy, so it is never replaced here.
if not user_image:
    user_image = "https://raw.githubusercontent.com/arvkevi/img2cmap/main/tests/images/south_beach_sunset.jpg"

# user settings
st.sidebar.markdown("### User settings")
n_colors = st.sidebar.number_input(
    "Number of colors", min_value=2, max_value=20, value=5, help="The number of colors to return in the colormap"
)
n_colors = int(n_colors)
remove_transparent = st.sidebar.checkbox(
    "Remove transparency", False, help="If checked, remove transparent pixels from the image before clustering."
)
random_state = st.sidebar.number_input("Random state", value=42, help="Random state for reproducibility")
random_state = int(random_state)


@st.cache(allow_output_mutation=True)
def get_image_converter(user_image, remove_transparent):
    """Return an ``ImageConverter`` for *user_image*, memoised by ``st.cache``.

    Args:
        user_image: Image source handed straight to ``ImageConverter``. In this
            app it is either an image URL string or a ``BytesIO`` holding the
            uploaded file's bytes.
        remove_transparent: Forwarded as ``ImageConverter``'s
            ``remove_transparent`` keyword argument.

    Returns:
        The ``ImageConverter`` instance. ``allow_output_mutation=True`` lets
        ``st.cache`` hand back the cached object without checking whether it
        was mutated by callers.
    """
    return ImageConverter(user_image, remove_transparent=remove_transparent)


converter = get_image_converter(user_image, remove_transparent)

with st.spinner("Generating colormap..."):
    cmap = converter.generate_cmap(n_colors=n_colors, palette_name="", random_state=random_state)

# Plot the image with the generated colormap attached as a colorbar.
fig1, ax1 = plt.subplots(figsize=(8, 8))
ax1.axis("off")
img = converter.image
im = ax1.imshow(img, cmap=cmap)

# Split a slim axis off the right-hand side of the image axis so the colorbar
# matches the image's height.
divider = make_axes_locatable(ax1)
cax = divider.append_axes("right", size="10%", pad=0.05)

cb = fig1.colorbar(im, cax=cax, orientation="vertical", label=cmap.name)
cb.set_ticks([])
st.pyplot(fig1)
# pyplot keeps figures created via plt.subplots open until explicitly closed,
# and this script body runs again on every rerun; close the figure once it has
# been rendered so open figures do not accumulate.
plt.close(fig1)

# Show the palette as hex strings for copying.
colors1 = [mpl.colors.rgb2hex(c) for c in cmap.colors]
st.text("Hex Codes (click to copy on far right)")
st.code(colors1)


st.header("Detect optimal number of colors")
optimize = st.button("Optimize")
if optimize:
    # Largest palette size to try; also determines the vertical extent of the plot.
    max_colors = 20
    with st.spinner("Optimizing... (this takes about a minute)"):
        cmaps, best_n_colors, ssd = converter.generate_optimal_cmap(
            max_colors=max_colors, palette_name="", random_state=random_state
        )

    figopt, ax = plt.subplots(figsize=(7, 5))

    # Each colormap is drawn as a horizontal band of height 1 whose bottom edge
    # sits at y = its number of colors, so the y-range runs from 2 up to one
    # past the largest palette.
    ymax = max_colors + 1
    # Every band spans the same x-range, split evenly among its colors.
    xmax = 20
    ax.set_ylim(2, ymax)
    ax.set_xlim(0, xmax)

    for y, cmap_ in cmaps.items():
        # Sort the hex strings so the order of colors within a band is deterministic.
        colors = sorted([mpl.colors.rgb2hex(c) for c in cmap_.colors])
        # Cut [0, xmax] into len(colors) equal-width cells, one per color.
        intervals, width = np.linspace(0, xmax, len(colors) + 1, retstep=True)
        for x0, color in zip(intervals, colors):
            rect = patches.Rectangle((x0, y), width, 1, facecolor=color)
            ax.add_patch(rect)

    # Place each tick label at the vertical center of its band.
    ax.set_yticks(np.arange(2, ymax) + 0.5)
    ax.set_yticklabels(np.arange(2, ymax))
    ax.set_ylabel("Number of colors")
    ax.set_xticks([])

    # Outline the optimal band with a dashed box. Its width is xmax (the
    # x-limit), so all four edges fall inside the axes and none is clipped.
    rect = patches.Rectangle(
        (0, best_n_colors), xmax, 1, linewidth=1, facecolor="none", edgecolor="black", linestyle="--"
    )
    ax.add_patch(rect)

    # Offset by 2: tick labels start at 2 colors, and list indices start at 0.
    ax.get_yticklabels()[best_n_colors - 2].set_color("red")
    st.pyplot(figopt)
    # Release the pyplot-managed figure once rendered so reruns don't accumulate figures.
    plt.close(figopt)

    st.metric("Optimal number of colors", best_n_colors)
    st.text("Hex Codes of optimal colormap (click to copy on far right)")
    st.code(sorted([mpl.colors.rgb2hex(c) for c in cmaps[best_n_colors].colors]))

    st.text("Sum of squared distances by number of colors:")
    st.write(ssd)