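# Andrei Cozma · Updates · 50413ea · raw · history blame · 12.6 kB
# NOTE(review): the line above appears to be residue from the web page this
# file was copied from (author, commit message/hash, file size); it is kept
# as a comment for provenance and is not part of the program.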
import os
from typing import Tuple, Union
import gradio as gr
import numpy as np
from PIL import Image, ImageChops, ImageOps
class ImageInfo:
    """Debug summary of an image: size, channel count, dtype/mode and value range.

    Instances are built from either a PIL image or a NumPy array and are
    printed (via ``__str__``) throughout the app for diagnostics.
    """

    def __init__(
        self,
        size: Tuple[int, int],
        channels: int,
        data_type: str,
        min_val: float,
        max_val: float,
    ):
        # size is (width, height), matching PIL's convention.
        self.size = size
        self.channels = channels
        self.data_type = data_type
        self.min_val = min_val
        self.max_val = max_val

    @classmethod
    def from_pil(cls, pil_image: "Image.Image") -> "ImageInfo":
        """Summarise a PIL image; ``data_type`` is the PIL mode string (e.g. "RGB")."""
        size = (pil_image.width, pil_image.height)
        channels = len(pil_image.getbands())
        data_type = str(pil_image.mode)
        extrema = pil_image.getextrema()
        if channels > 1:
            # Multi-band images report one (min, max) pair per band.
            min_val = min(band[0] for band in extrema)
            max_val = max(band[1] for band in extrema)
        else:
            min_val, max_val = extrema
        return cls(size, channels, data_type, min_val, max_val)

    @classmethod
    def from_numpy(cls, np_array: np.ndarray) -> "ImageInfo":
        """Summarise an ``(H, W)`` or ``(H, W, C)`` array.

        Raises:
            ValueError: If the array is not 2-D/3-D or has no elements.
        """
        # Reject shapes up front: 1-D arrays previously raised IndexError and
        # empty arrays failed inside min()/max().
        if np_array.ndim not in (2, 3) or np_array.size == 0:
            raise ValueError(f"Unsupported array shape: {np_array.shape}")
        size = (np_array.shape[1], np_array.shape[0])
        channels = 1 if np_array.ndim == 2 else np_array.shape[2]
        data_type = str(np_array.dtype)
        min_val, max_val = np_array.min(), np_array.max()
        return cls(size, channels, data_type, min_val, max_val)

    @classmethod
    def from_any(cls, image: "Union[Image.Image, np.ndarray]") -> "ImageInfo":
        """Dispatch to ``from_numpy`` or ``from_pil`` based on the input type.

        Raises:
            ValueError: For any other input type (or an unsupported array shape).
        """
        if isinstance(image, np.ndarray):
            return cls.from_numpy(image)
        if isinstance(image, Image.Image):
            return cls.from_pil(image)
        raise ValueError(f"Unsupported image type: {type(image)}")

    def __str__(self) -> str:
        return f"{str(self.size)} {self.channels}C {self.data_type} {round(self.min_val, 2)}min/{round(self.max_val, 2)}max"

    @property
    def aspect_ratio(self) -> float:
        """Width divided by height."""
        return self.size[0] / self.size[1]
def nextpow2(n):
    """Return the smallest power of two that is greater than or equal to `n`.

    Uses exact integer arithmetic (``int.bit_length``), so the result is
    correct for arbitrarily large integers; a floating-point
    ``2 ** ceil(log2(n))`` rounds down for inputs such as ``2**50 + 1``.
    Non-integer inputs are rounded up first. Any ``n <= 1`` (including 0)
    returns 1, i.e. ``2**0``.

    Returns:
        A Python ``int``.
    """
    import math

    m = math.ceil(n)
    if m <= 1:
        return 1
    # For m > 1, (m - 1).bit_length() is the exponent of the next power of two >= m.
    return 1 << (m - 1).bit_length()
def pad_image_nextpow2(image, mode="constant"):
    """Pad an image so that its height and width are powers of two.

    Padding is split between both sides of each spatial axis; when the amount
    is odd, the extra row/column goes to the bottom/right. A 2-D (grayscale)
    image is first promoted to shape ``(H, W, 1)``. The channel axis is never
    padded.

    Args:
        image: NumPy array of shape ``(H, W)`` or ``(H, W, C)``.
        mode: Padding mode passed straight to ``np.pad``. The default
            ``"constant"`` pads with zeros; other ``np.pad`` modes such as
            ``"edge"``, ``"reflect"`` or ``"wrap"`` may be used instead.

    Returns:
        Array of shape ``(nextpow2(H), nextpow2(W), C)``.

    Raises:
        ValueError: If ``image`` is not 2-D or 3-D.
    """
    print("-" * 80)
    print("pad_image_nextpow2: ")
    print(ImageInfo.from_any(image))
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image.ndim != 3:
        raise ValueError(f"Expected image.ndim == 3. Got {image.ndim}")
    height, width, _ = image.shape
    height_diff = nextpow2(height) - height
    width_diff = nextpow2(width) - width
    image = np.pad(
        image,
        (
            (height_diff // 2, height_diff - height_diff // 2),
            (width_diff // 2, width_diff - width_diff // 2),
            (0, 0),
        ),
        mode=mode,
    )
    print(ImageInfo.from_any(image))
    return image
def get_fft(image):
    """Return the centred discrete Fourier transform of an image array.

    The transform runs over the first (up to) three axes. For an
    ``(H, W, C)`` array that means rows, columns and the channel axis. For a
    2-D grayscale array it means both spatial axes. ``np.fft.fftshift`` then
    moves the zero-frequency term to the centre of every axis.
    """
    print("-" * 80)
    print("get_fft: ")
    print("image:", ImageInfo.from_any(image))
    # Same axes as before for 3-D input ((0, 1, 2)); adapts to 2-D input
    # instead of failing with an out-of-range axis.
    # NOTE(review): including the channel axis mixes colour channels in the
    # transform; a per-channel 2-D FFT would use axes=(0, 1). Kept to preserve
    # the existing spectra and the round trip with get_ifft_image.
    axes = tuple(range(min(image.ndim, 3)))
    fft = np.fft.fftn(image, axes=axes)
    return np.fft.fftshift(fft)
def get_ifft_image(fft):
    """Invert a centred FFT (as produced by ``get_fft``) into a uint8 image.

    The real part of the inverse transform is min-max stretched to
    ``[0, 255]``. A constant result has no range to stretch; its values are
    clipped to ``[0, 255]`` instead of dividing by zero. The output keeps the
    FFT's spatial size, so any padding added before the FFT is not cropped here.
    """
    print("-" * 80)
    print("get_ifft_image: ")
    # Invert over the same axes get_fft transforms (up to the first three).
    axes = tuple(range(min(fft.ndim, 3)))
    ifft = np.fft.ifftn(np.fft.ifftshift(fft), axes=axes)
    # Only the real part is displayed.
    ifft_image = np.real(ifft)
    lo, hi = np.min(ifft_image), np.max(ifft_image)
    if hi > lo:
        ifft_image = (ifft_image - lo) / (hi - lo) * 255
    else:
        ifft_image = np.clip(ifft_image, 0, 255)
    return ifft_image.astype(np.uint8)
def fft_mag_image(fft):
    """Render the log-magnitude of an FFT as a uint8 image.

    Computes ``log(|fft| + 1)`` and min-max stretches it so the smallest value
    maps to 0 and the largest to 255. A constant spectrum has no range to
    stretch and maps to all zeros.
    """
    print("-" * 80)
    print("fft_mag_image: ")
    # The log compresses the spectrum's dynamic range for display.
    fft_mag = np.log(np.abs(fft) + 1)
    lo, hi = np.min(fft_mag), np.max(fft_mag)
    # Explicit guard instead of an epsilon in the denominator: the epsilon
    # pushed the maximum to 254.999..., which truncates to 254 instead of 255.
    if hi > lo:
        fft_mag = (fft_mag - lo) / (hi - lo)
    else:
        fft_mag = np.zeros_like(fft_mag)
    return (fft_mag * 255).astype(np.uint8)
def fft_phase_image(fft):
    """Render the phase of an FFT as a uint8 image.

    ``np.angle`` gives values in ``[-pi, pi]``; these are shifted and scaled
    into ``[0, 1]`` and then min-max stretched so the smallest phase present
    maps to 0 and the largest to 255. A constant phase (e.g. the all-zero
    phase of a constant image) has no range to stretch and maps to all zeros
    instead of dividing by zero.
    """
    print("-" * 80)
    print("fft_phase_image: ")
    fft_phase = (np.angle(fft) + np.pi) / (2 * np.pi)
    lo, hi = np.min(fft_phase), np.max(fft_phase)
    if hi > lo:
        fft_phase = (fft_phase - lo) / (hi - lo)
    else:
        fft_phase = np.zeros_like(fft_phase)
    return (fft_phase * 255).astype(np.uint8)
def _apply_sketch_mask(image, mask, mask_opacity, inverted_mask):
    """Multiply `image` by `mask`, then blend with the unmasked image by `mask_opacity`."""
    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(mask)
    if not inverted_mask:
        mask_pil = ImageOps.invert(mask_pil)
    masked = ImageChops.multiply(image_pil, mask_pil)
    # Image.blend: opacity 0 keeps the original image, 1 gives the fully masked one.
    blended = Image.blend(image_pil, masked, mask_opacity)
    return np.array(blended.convert(image_pil.mode))


def onclick_process_fft(state, inp_image, mask_opacity, inverted_mask, pad):
    """Build the final input image and compute its FFT magnitude/phase views.

    Args:
        state: Session state dict. ``state["inp_image"]`` is overwritten with
            the final (masked and optionally padded) image.
        inp_image: Either a dict with ``"image"`` and ``"mask"`` arrays (from
            the sketch tool) or a plain NumPy array used as-is.
        mask_opacity: Blend factor between the unmasked image (0) and the
            mask-multiplied image (1).
        inverted_mask: If False, the mask is inverted with ``ImageOps.invert``
            before being multiplied into the image.
        pad: If True, pad the image to power-of-two height and width.

    Returns:
        ``(gallery_items, fft_magnitude_image, fft_phase_image)``.

    Raises:
        gr.Error: If no image has been provided.
    """
    print("-" * 80)
    print("onclick_process_fft:")
    if isinstance(inp_image, dict):
        image, mask = inp_image.get("image"), inp_image.get("mask")
        if image is None:
            raise gr.Error("Please upload or select an image first.")
        print("image:", ImageInfo.from_any(image))
        if mask is None:
            # No mask available: use the image unmodified.
            image_final = image
        else:
            print("mask:", ImageInfo.from_any(mask))
            image_final = _apply_sketch_mask(image, mask, mask_opacity, inverted_mask)
    elif isinstance(inp_image, np.ndarray):
        image_final = inp_image
    else:
        raise gr.Error("Please upload or select an image first.")
    print("image_final:", ImageInfo.from_any(image_final))
    if pad:
        image_final = pad_image_nextpow2(image_final)
    state["inp_image"] = image_final
    # Compute the FFT once and derive both visualisations from it.
    fft = get_fft(image_final)
    image_mag = fft_mag_image(fft)
    image_phase = fft_phase_image(fft)
    return (
        [
            (image_final, "Input Image (Final)"),
            (image_mag, "FFT Magnitude (Original)"),
            (image_phase, "FFT Phase (Original)"),
        ],
        image_mag,
        image_phase,
    )
def onclick_process_ifft(state, mag_and_mask, phase_and_mask):
    """Zero the masked spectrum entries and reconstruct the image via IFFT.

    The FFT of ``state["inp_image"]`` is recomputed. Wherever a mask pixel
    equals 255, the corresponding magnitude (from ``mag_and_mask``) or phase
    (from ``phase_and_mask``) value is set to 0. Magnitude and phase are then
    recombined and inverted with ``get_ifft_image``.

    Args:
        state: Session state; ``state["inp_image"]`` must hold the image from
            the last FFT run.
        mag_and_mask: Dict with a ``"mask"`` array drawn on the magnitude image.
        phase_and_mask: Dict with a ``"mask"`` array drawn on the phase image.

    Returns:
        ``(gallery_items, reconstructed_uint8_image)``.

    Raises:
        gr.Error: If FFT has not been processed, a spectrum input is empty, or
            the masks' spatial size does not match the FFT.
    """
    print("-" * 80)
    print("onclick_process_ifft:")
    image = state["inp_image"]
    # Empty spectrum components (e.g. right after a new example is selected)
    # arrive as None; fail with a user-facing message instead of a TypeError.
    if image is None or mag_and_mask is None or phase_and_mask is None:
        raise gr.Error("Please process FFT first.")
    mask_mag = mag_and_mask["mask"]
    print("mag_mask:", ImageInfo.from_any(mask_mag))
    mask_phase = phase_and_mask["mask"]
    print("phase_mask:", ImageInfo.from_any(mask_phase))
    fft = get_fft(image)
    print(f"fft: {fft.shape}")
    if mask_mag.shape[:2] != fft.shape[:2] or mask_phase.shape[:2] != fft.shape[:2]:
        raise gr.Error(
            "Spectrum masks do not match the current FFT size. Please process FFT again."
        )
    fft_mag = np.where(mask_mag == 255, 0, np.abs(fft))
    fft_phase = np.where(mask_phase == 255, 0, np.angle(fft))
    fft = fft_mag * np.exp(1j * fft_phase)
    ifft_image = get_ifft_image(fft)
    image_mag = fft_mag_image(fft)
    image_phase = fft_phase_image(fft)
    return (
        [
            (image, "Input Image (Final)"),
            (image_mag, "FFT Magnitude (Filtered)"),
            (image_phase, "FFT Phase (Filtered)"),
        ],
        ifft_image,
    )
def get_start_image():
    """Return a blank, all-white 512x512 RGB canvas as a uint8 array."""
    return np.full(shape=(512, 512, 3), fill_value=255, dtype=np.uint8)
def update_image_input(state, selection):
    """Load the selected example image from ./images, or reset to a blank canvas.

    Args:
        state: Session state; ``state["inp_image"]`` is set to the loaded image.
            It is left unchanged when the selection is empty.
        selection: File name chosen in the example dropdown (falsy to reset).

    Returns:
        ``(input_image, gallery_items, None, None, None)``. The trailing Nones
        clear the FFT magnitude, FFT phase and IFFT outputs.

    Raises:
        gr.Error: If the selection is not a plain file name or the file is missing.
    """
    print("-" * 80)
    print("update_image_input:")
    print(f"selection: {selection}")
    if not selection:
        white_image = get_start_image()
        return (
            white_image,
            [white_image],
            None,
            None,
            None,
        )
    # The dropdown value comes from the client. Only bare file names (as
    # produced by os.listdir) are accepted, so a crafted value cannot reach
    # files outside ./images.
    if os.path.basename(selection) != selection:
        raise gr.Error(f"Invalid image selection: {selection}")
    image_path = os.path.join("./images", selection)
    print(f"image_path: {image_path}")
    if not os.path.isfile(image_path):
        raise gr.Error(f"Image not found: {image_path}")
    # np.array() copies the pixels, so the file can be closed right away.
    with Image.open(image_path) as pil_image:
        image = np.array(pil_image)
    state["inp_image"] = image
    return (
        image,
        [image],
        None,
        None,
        None,
    )
def clear_image_input(state):
    """Reset the session after the input image is cleared.

    Drops the stored image from ``state``. Returns empty values for, in order:
    the example dropdown, the gallery, the FFT magnitude, the FFT phase and
    the IFFT output.
    """
    for line in ("-" * 80, "clear_image_input:"):
        print(line)
    state.update(inp_image=None)
    empty_gallery = []
    return (None, empty_gallery, None, None, None)
# Custom CSS passed to gr.Blocks. Each rule hides the first <div> inside the
# image-container button of the components tagged with the "fft_mag",
# "fft_phase" and "ifft_img" elem_classes.
# NOTE(review): presumably this hides a Gradio-rendered overlay/icon on those
# image panels; the exact element depends on the Gradio version in use — confirm.
css = """
.fft_mag > .image-container > button > div:first-child {
display: none;
}
.fft_phase > .image-container > button > div:first-child {
display: none;
}
.ifft_img > .image-container > button > div:first-child {
display: none;
}
"""
# Gradio UI: an input image with a sketch-mask tool and an example picker, a
# results gallery, FFT options, editable FFT magnitude/phase spectra and the
# IFFT output, followed by the event wiring between them.
# NOTE(review): this block was recovered from a copy with all indentation
# stripped; the nesting of the Row/Column containers below is a best-effort
# reconstruction and should be checked against the intended layout.
with gr.Blocks(css=css) as demo:
    # Per-session state. "inp_image" is written by onclick_process_fft and
    # update_image_input, reset to None by clear_image_input, and read by
    # onclick_process_ifft.
    state = gr.State(
        {
            "inp_image": None,
        },
    )
    with gr.Row():
        with gr.Column():
            # Input image. With tool="sketch" the handler receives a dict with
            # "image" and "mask", which onclick_process_fft applies before the FFT.
            inp_image = gr.Image(
                value=get_start_image(),
                label="Input Image",
                height=512,
                type="numpy",
                interactive=True,
                tool="sketch",
                mask_opacity=1.0,
                elem_classes=["inp_img"],
            )
            # Example images are listed from ./images once, when the UI is
            # built; os.listdir raises if that folder does not exist.
            files = os.listdir("./images")
            files = sorted(files)
            inp_samples = gr.Dropdown(
                choices=files,
                label="Select Example Image",
            )
        with gr.Column():
            # Shows the final input image and its spectra after each FFT/IFFT run.
            out_gallery = gr.Gallery(
                label="Input Gallery",
                height=512,
                rows=1,
                columns=3,
                allow_preview=True,
                preview=False,
                selected_index=None,
            )
    with gr.Row():
        # Options forwarded to onclick_process_fft.
        inp_mask_opacity = gr.Slider(
            label="Mask Opacity",
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            value=1.0,
        )
        inp_invert_mask = gr.Checkbox(
            label="Invert Mask",
            value=False,
        )
        inp_pad = gr.Checkbox(
            label="Pad NextPow2",
            value=True,
        )
    btn_fft = gr.Button("Process FFT")
    # Editable spectra: strokes drawn here become masks; onclick_process_ifft
    # zeroes magnitude/phase values wherever a mask pixel equals 255.
    out_fft_mag = gr.Image(
        label="FFT Magnitude Spectrum",
        height=512,
        type="numpy",
        interactive=True,
        # source="canvas",
        tool="sketch",
        mask_opacity=1.0,
        elem_classes=["fft_mag"],
    )
    out_fft_phase = gr.Image(
        label="FFT Phase Spectrum",
        height=512,
        type="numpy",
        interactive=True,
        # source="canvas",
        tool="sketch",
        mask_opacity=1.0,
        elem_classes=["fft_phase"],
    )
    btn_ifft = gr.Button("Process IFFT")
    out_ifft = gr.Image(
        label="IFFT",
        height=512,
        type="numpy",
        interactive=True,
        show_download_button=True,
        elem_classes=["ifft_img"],
    )
    # Clearing the input image resets the state and empties the dropdown,
    # gallery and all spectrum/IFFT outputs.
    inp_image.clear(
        clear_image_input,
        [state],
        [inp_samples, out_gallery, out_fft_mag, out_fft_phase, out_ifft],
    )
    # Set up event listener for the Dropdown component to update the image input
    inp_samples.change(
        update_image_input,
        [state, inp_samples],
        [inp_image, out_gallery, out_fft_mag, out_fft_phase, out_ifft],
    )
    # Set up events for fft processing
    btn_fft.click(
        onclick_process_fft,
        [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad],
        [out_gallery, out_fft_mag, out_fft_phase],
    )
    # Clearing either spectrum re-runs onclick_process_fft, which repopulates
    # the gallery and both spectrum images.
    out_fft_mag.clear(
        onclick_process_fft,
        [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad],
        [out_gallery, out_fft_mag, out_fft_phase],
    )
    out_fft_phase.clear(
        onclick_process_fft,
        [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad],
        [out_gallery, out_fft_mag, out_fft_phase],
    )
    # NOTE(review): the disabled live-update wiring below refers to
    # `get_fft_images`, which is not defined in this file, and passes three
    # inputs to `get_ifft_image`, which takes a single FFT array. Update these
    # before re-enabling them.
    # inp_image.edit(
    #     get_fft_images,
    #     [state, inp_image],
    #     [out_gallery, out_fft_mag, out_fft_phase],
    # )
    # Set up events for ifft processing
    btn_ifft.click(
        onclick_process_ifft,
        [state, out_fft_mag, out_fft_phase],
        [out_gallery, out_ifft],
    )
    # out_fft_mag.edit(
    #     get_ifft_image,
    #     [state, out_fft_mag, out_fft_phase],
    #     [out_ifft],
    # )
    # out_fft_phase.edit(
    #     get_ifft_image,
    #     [state, out_fft_mag, out_fft_phase],
    #     [out_ifft],
    # )


# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()