Mattral's picture
Create app.py
1219288 verified
raw
history blame
6.88 kB
import streamlit as st
import numpy as np
import cv2
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
# FFT processing functions
def apply_fft(image):
    """Compute the shifted 2-D FFT of every channel of an image.

    The image is separated into planes with ``cv2.split``. Each plane is
    transformed with ``np.fft.fft2`` and the result is passed through
    ``np.fft.fftshift``.

    Returns:
        A list holding one shifted complex spectrum per channel, in the
        order ``cv2.split`` yields the channels.
    """
    return [np.fft.fftshift(np.fft.fft2(plane)) for plane in cv2.split(image)]
def filter_fft_percentage(fft_channels, percentage):
    """Keep only the strongest frequency components of each channel.

    For every spectrum in ``fft_channels`` the coefficients are ranked by
    magnitude; the top ``percentage`` percent (rounded down to a whole
    number of coefficients) are kept and all others are set to zero.

    Args:
        fft_channels: Iterable of arrays, one spectrum per channel.
        percentage: Share of coefficients to retain, in percent. Values
            outside ``[0, 100]`` are clamped to that range.

    Returns:
        A list with one array per input channel, same shape as the input,
        with the discarded coefficients zeroed. Coefficients whose magnitude
        ties with the cut-off are all kept, so slightly more than the
        requested share may survive.
    """
    pct = min(max(float(percentage), 0.0), 100.0)
    filtered_fft = []
    for fft_data in fft_channels:
        fft_data = np.asarray(fft_data)
        magnitude = np.abs(fft_data)
        num_keep = int(magnitude.size * pct / 100)
        if num_keep <= 0:
            # Nothing qualifies. Falling back to a threshold of 0 here would
            # keep *every* coefficient, because all magnitudes are >= 0.
            filtered_fft.append(np.zeros_like(fft_data))
            continue
        # The num_keep-th largest magnitude is the element at position
        # size - num_keep in ascending order; partition finds it without
        # sorting the whole array.
        cut = magnitude.size - num_keep
        threshold = np.partition(magnitude, cut, axis=None)[cut]
        filtered_fft.append(fft_data * (magnitude >= threshold))
    return filtered_fft
def inverse_fft(filtered_fft):
    """Reconstruct an 8-bit image from shifted (and possibly filtered) FFT channels.

    Each spectrum is un-shifted with ``np.fft.ifftshift``, inverted with
    ``np.fft.ifft2``, and the real part is converted to ``uint8``.

    Args:
        filtered_fft: Sequence of shifted complex spectra, one per channel.

    Returns:
        A ``uint8`` array. Several channels are stacked along the last axis
        (``rows x cols x channels``); a single channel is returned as a 2-D
        array.

    Raises:
        ValueError: If ``filtered_fft`` is empty (raised by ``np.stack``).
    """
    reconstructed_channels = []
    for fft_data in filtered_fft:
        spatial = np.fft.ifft2(np.fft.ifftshift(fft_data)).real
        # Round and clip to the 8-bit range instead of min-max stretching each
        # channel on its own: independent stretching changes the balance
        # between channels, and plain truncation turns e.g. 254.9999 into 254,
        # so an unfiltered round trip would not give back the input.
        pixels = np.clip(np.rint(spatial), 0, 255).astype(np.uint8)
        reconstructed_channels.append(pixels)
    if len(reconstructed_channels) == 1:
        # Keep a lone channel in 2-D (grayscale) layout.
        return reconstructed_channels[0]
    return np.stack(reconstructed_channels, axis=-1)
def create_3d_plot(fft_channels, downsample_factor=1):
    """Create interactive 3D surface plots of FFT magnitude and phase using Plotly.

    One subplot row is drawn per channel: magnitude in the left column and
    phase in the right column.

    Args:
        fft_channels: Iterable of 2-D complex spectra. The axis labels assume
            the spectra are fftshift-ed, as ``apply_fft`` produces them.
        downsample_factor: Keep every n-th sample along both axes to lighten
            the plots. Values below 1 are treated as 1.

    Returns:
        The Plotly figure containing all surfaces.
    """
    channels = list(fft_channels)
    # A step of 0 would make the slice below raise.
    step = max(1, int(downsample_factor))

    # Three channels are labelled Blue/Green/Red as before; any other count
    # gets neutral names instead of failing on a missing subplot row.
    if len(channels) == 3:
        channel_names = ['Blue', 'Green', 'Red']
    else:
        channel_names = [f'Channel {i + 1}' for i in range(len(channels))]
    subplot_titles = []
    for name in channel_names:
        subplot_titles.extend([f'{name} - Magnitude', f'{name} - Phase'])

    n_rows = max(1, len(channels))
    fig = make_subplots(
        rows=n_rows, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}] for _ in range(n_rows)],
        subplot_titles=tuple(subplot_titles) or None,
    )

    for row, fft_data in enumerate(channels, start=1):
        fft_data = np.asarray(fft_data)
        full_rows, full_cols = fft_data.shape

        # Downsample data for better performance
        fft_down = fft_data[::step, ::step]

        # With a shifted spectrum the zero frequency sits at index n // 2, so
        # (index - n // 2) is the frequency bin. Slicing the full-resolution
        # axes with the same step keeps the labels in original bins rather
        # than in downsampled sample counts.
        x = (np.arange(full_cols) - full_cols // 2)[::step]
        y = (np.arange(full_rows) - full_rows // 2)[::step]
        X, Y = np.meshgrid(x, y)

        # Magnitude plot
        fig.add_trace(
            go.Surface(x=X, y=Y, z=np.abs(fft_down), colorscale='Viridis', showscale=False),
            row=row, col=1
        )
        # Phase plot
        fig.add_trace(
            go.Surface(x=X, y=Y, z=np.angle(fft_down), colorscale='Inferno', showscale=False),
            row=row, col=2
        )

    # Update layout for better visualization
    fig.update_layout(
        height=500 * n_rows,  # 1500 for the usual three channels
        width=1200,
        margin=dict(l=0, r=0, b=0, t=30),
    )
    # layout.scene only addresses the first subplot; update_scenes applies the
    # camera and axis titles to every scene in the grid.
    fig.update_scenes(
        camera=dict(eye=dict(x=1.5, y=1.5, z=0.5)),
        xaxis=dict(title='Frequency X'),
        yaxis=dict(title='Frequency Y'),
        zaxis=dict(title='Magnitude/Phase'),
    )
    return fig
# Streamlit UI
# Page flow: upload an image -> cache its per-channel FFT in session state ->
# choose a retention percentage in a form -> show the reconstruction, FFT
# tables, a download button and 3D plots of the filtered spectrum.
st.set_page_config(layout="wide")
st.title("Interactive Frequency Domain Analysis")

# Introduction Text
st.subheader("Introduction to FFT and Image Filtering")
st.write(
    """Fast Fourier Transform (FFT) is a technique to transform an image from the spatial domain to the frequency domain.
In this domain, each frequency represents a different aspect of the image's texture and details.
By filtering out certain frequencies, you can modify the image's appearance, enhancing or suppressing certain features."""
)

uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])

if uploaded_file is not None:
    # Read and display original image
    raw_bytes = uploaded_file.getvalue()
    file_bytes = np.frombuffer(raw_bytes, np.uint8)
    # cv2.imdecode returns None for data it cannot decode; an empty buffer is
    # rejected up front so both cases end in the same friendly message.
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
        st.stop()
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Original Image", use_column_width=True)

    # Process FFT and store in session state. The cache is keyed on the
    # uploaded content so that a different image replaces the old spectrum
    # instead of silently reusing it.
    image_key = (uploaded_file.name, len(raw_bytes), hash(raw_bytes))
    if st.session_state.get('fft_image_key') != image_key:
        st.session_state.fft_channels = apply_fft(image)
        st.session_state.fft_image_key = image_key
        # A new image needs a fresh "Apply Filter" before results are shown.
        st.session_state.applied_percentage = None

    # Create a form to submit frequency percentage selection
    with st.form(key='fft_form'):
        percentage = st.slider(
            "Percentage of frequencies to retain:",
            min_value=0.1, max_value=100.0, value=10.0, step=0.1,
            help="Adjust the slider to select what portion of frequency components to keep. Lower values blur the image."
        )
        submit_button = st.form_submit_button(label="Apply Filter")

    # The submit flag is True only on the rerun triggered by the click.
    # Remember the chosen percentage so the results (and the downsampling
    # slider below) survive the reruns caused by other widgets.
    if submit_button:
        st.session_state.applied_percentage = percentage

    applied_percentage = st.session_state.get('applied_percentage')
    if applied_percentage is not None:
        # Apply filtering and reconstruct image
        filtered_fft = filter_fft_percentage(st.session_state.fft_channels, applied_percentage)
        reconstructed = inverse_fft(filtered_fft)
        reconstructed_rgb = cv2.cvtColor(reconstructed, cv2.COLOR_BGR2RGB)
        st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)

        # Display FFT Data in Table Format
        st.subheader("Frequency Data of Each Channel")
        for channel_name, channel_fft in zip(['Blue', 'Green', 'Red'], st.session_state.fft_channels):
            st.write(f"### {channel_name} Channel FFT Data")
            magnitude_df = pd.DataFrame(np.abs(channel_fft))
            phase_df = pd.DataFrame(np.angle(channel_fft))
            st.write("#### Magnitude Data:")
            st.dataframe(magnitude_df.head(10))  # Display first 10 rows for brevity
            st.write("#### Phase Data:")
            st.dataframe(phase_df.head(10))  # Display first 10 rows for brevity

        # Download button for reconstructed image
        encoded_ok, encoded_img = cv2.imencode('.png', reconstructed)
        if encoded_ok:
            st.download_button(
                "Download Reconstructed Image",
                encoded_img.tobytes(),
                "reconstructed.png",
                "image/png"
            )
        else:
            st.warning("Could not encode the reconstructed image as PNG.")

        # 3D visualization controls
        st.subheader("3D Frequency Components Visualization")
        downsample = st.slider(
            "Downsampling factor for 3D plots:",
            min_value=1, max_value=20, value=5,
            help="Controls the resolution of the 3D surface plots. Higher values improve performance but reduce the plot's detail."
        )

        # Generate and display 3D plots
        fig = create_3d_plot(filtered_fft, downsample)
        st.plotly_chart(fig, use_container_width=True)