import hashlib

import cv2
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

# OpenCV decodes colour images in BGR order, so channel 0 is blue.
CHANNEL_NAMES = ("Blue", "Green", "Red")


# FFT processing functions
def apply_fft(image):
    """Apply FFT to each channel of the image and return shifted FFT channels.

    Parameters
    ----------
    image : numpy.ndarray
        Image to transform; every channel produced by ``cv2.split`` is
        transformed independently.

    Returns
    -------
    list of numpy.ndarray
        One complex 2-D spectrum per channel, with the zero-frequency term
        moved to the centre by ``np.fft.fftshift``.
    """
    return [np.fft.fftshift(np.fft.fft2(channel)) for channel in cv2.split(image)]


def filter_fft_percentage(fft_channels, percentage):
    """Filter FFT channels to keep the top ``percentage`` of magnitudes.

    Parameters
    ----------
    fft_channels : iterable of numpy.ndarray
        Complex spectra, e.g. as returned by :func:`apply_fft`.
    percentage : float
        Share (0-100) of coefficients to keep per channel. Any positive value
        keeps at least the single strongest coefficient; values above 100
        keep everything; values <= 0 keep nothing.

    Returns
    -------
    list of numpy.ndarray
        Spectra of the same shape with all weaker coefficients set to zero.
        Coefficients tied with the cut-off magnitude are all kept, so slightly
        more than the requested share may survive.
    """
    filtered_fft = []
    for fft_data in fft_channels:
        magnitude = np.abs(fft_data)
        total = magnitude.size
        if total == 0 or percentage <= 0:
            filtered_fft.append(np.zeros_like(fft_data))
            continue
        # Never let the kept count fall to 0: a zero threshold would make the
        # mask below keep *every* coefficient instead of almost none.
        num_keep = min(total, max(1, int(total * percentage / 100)))
        # The num_keep-th largest magnitude, found without a full sort.
        kth = total - num_keep
        threshold = np.partition(magnitude.ravel(), kth)[kth]
        filtered_fft.append(fft_data * (magnitude >= threshold))
    return filtered_fft


def inverse_fft(filtered_fft, normalize=False):
    """Reconstruct an 8-bit image from (filtered) shifted FFT channels.

    Parameters
    ----------
    filtered_fft : iterable of numpy.ndarray
        Centred complex spectra, one per channel.
    normalize : bool, optional
        If True, min-max stretch every channel to 0-255 independently (as
        earlier versions did; this shifts the colour balance). By default the
        values are rounded and clipped to 0-255, so keeping 100 % of the
        frequencies reproduces the original image.

    Returns
    -------
    numpy.ndarray
        ``uint8`` image with the channels merged back by ``cv2.merge``.
    """
    reconstructed_channels = []
    for fft_data in filtered_fft:
        # The magnitude-based mask is symmetric under k -> -k for a real
        # input, so the imaginary part should be only rounding noise.
        spatial = np.fft.ifft2(np.fft.ifftshift(fft_data)).real
        if normalize:
            spatial = cv2.normalize(spatial, None, 0, 255, cv2.NORM_MINMAX)
        reconstructed_channels.append(
            np.clip(np.rint(spatial), 0, 255).astype(np.uint8)
        )
    return cv2.merge(reconstructed_channels)


def create_3d_plot(fft_channels, downsample_factor=1):
    """Create interactive 3D surface plots of FFT magnitude and phase.

    Parameters
    ----------
    fft_channels : sequence of numpy.ndarray
        Centred complex spectra, one per channel (BGR order for OpenCV
        images).
    downsample_factor : int, optional
        Keep every ``downsample_factor``-th sample along each axis; values
        below 1 are treated as 1.

    Returns
    -------
    plotly.graph_objects.Figure
        Grid with one row per channel: magnitude on the left, phase on the
        right.
    """
    step = max(1, int(downsample_factor))
    names = [
        CHANNEL_NAMES[i] if i < len(CHANNEL_NAMES) else f"Channel {i}"
        for i in range(len(fft_channels))
    ]
    n_rows = max(1, len(fft_channels))
    fig = make_subplots(
        rows=n_rows,
        cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}] for _ in range(n_rows)],
        subplot_titles=[
            f"{name} - {kind}" for name in names for kind in ("Magnitude", "Phase")
        ],
    )

    for row, fft_data in enumerate(fft_channels, start=1):
        full_rows, full_cols = fft_data.shape
        # Downsample data for better performance
        fft_down = fft_data[::step, ::step]
        # After fftshift the DC term sits at index N // 2, so sample index k
        # corresponds to frequency k - N // 2; keep the stride in the axis.
        freq_y = np.arange(0, full_rows, step) - full_rows // 2
        freq_x = np.arange(0, full_cols, step) - full_cols // 2
        X, Y = np.meshgrid(freq_x, freq_y)

        fig.add_trace(
            go.Surface(x=X, y=Y, z=np.abs(fft_down), colorscale='Viridis', showscale=False),
            row=row, col=1,
        )
        fig.add_trace(
            go.Surface(x=X, y=Y, z=np.angle(fft_down), colorscale='Inferno', showscale=False),
            row=row, col=2,
        )

    fig.update_layout(height=500 * n_rows, width=1200, margin=dict(l=0, r=0, b=0, t=30))
    # Apply camera and axis titles to every scene, not only the first one.
    fig.update_scenes(
        camera=dict(eye=dict(x=1.5, y=1.5, z=0.5)),
        xaxis_title='Frequency X',
        yaxis_title='Frequency Y',
    )
    fig.update_scenes(zaxis_title='Magnitude', col=1)
    fig.update_scenes(zaxis_title='Phase', col=2)
    return fig


# Streamlit UI
st.set_page_config(layout="wide")
st.title("Interactive Frequency Domain Analysis")

# Introduction Text
st.subheader("Introduction to FFT and Image Filtering")
st.write(
    "Fast Fourier Transform (FFT) is a technique to transform an image from "
    "the spatial domain to the frequency domain. In this domain, each "
    "frequency represents a different aspect of the image's texture and "
    "details. By filtering out certain frequencies, you can modify the "
    "image's appearance, enhancing or suppressing certain features."
)

uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])

if uploaded_file is not None:
    # Read and display original image
    raw_bytes = uploaded_file.getvalue()
    image = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
        st.stop()
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Original Image", use_column_width=True)

    # Recompute the spectrum whenever a *different* file is uploaded; caching
    # it unconditionally would keep serving the first image's FFT.
    file_digest = hashlib.sha256(raw_bytes).hexdigest()
    if st.session_state.get("fft_source") != file_digest:
        st.session_state.fft_channels = apply_fft(image)
        st.session_state.fft_source = file_digest
        # A filter applied to the previous image no longer applies.
        st.session_state.pop("applied_percentage", None)

    # Create a form to submit frequency percentage selection
    with st.form(key='fft_form'):
        percentage = st.slider(
            "Percentage of frequencies to retain:",
            min_value=0.1,
            max_value=100.0,
            value=10.0,
            step=0.1,
            help=(
                "Adjust the slider to select what portion of frequency "
                "components to keep. Lower values blur the image."
            ),
        )
        submit_button = st.form_submit_button(label="Apply Filter")

    # Remember the submitted value: widgets outside the form (the downsampling
    # slider below) trigger reruns in which submit_button is False, and the
    # results must not disappear then.
    if submit_button:
        st.session_state.applied_percentage = percentage

    applied_percentage = st.session_state.get("applied_percentage")
    if applied_percentage is not None:
        # Apply filtering and reconstruct image
        filtered_fft = filter_fft_percentage(st.session_state.fft_channels, applied_percentage)
        reconstructed = inverse_fft(filtered_fft)
        reconstructed_rgb = cv2.cvtColor(reconstructed, cv2.COLOR_BGR2RGB)
        st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)

        # Display (unfiltered) FFT data in table format; only the first 10
        # rows are shown, so only those rows are converted.
        st.subheader("Frequency Data of Each Channel")
        for channel_name, fft_data in zip(CHANNEL_NAMES, st.session_state.fft_channels):
            st.write(f"### {channel_name} Channel FFT Data")
            st.write("#### Magnitude Data:")
            st.dataframe(pd.DataFrame(np.abs(fft_data[:10])))
            st.write("#### Phase Data:")
            st.dataframe(pd.DataFrame(np.angle(fft_data[:10])))

        # Download button for reconstructed image
        encoded_ok, encoded_img = cv2.imencode('.png', reconstructed)
        if encoded_ok:
            st.download_button(
                "Download Reconstructed Image",
                encoded_img.tobytes(),
                "reconstructed.png",
                "image/png",
            )
        else:
            st.warning("Could not encode the reconstructed image as PNG.")

        # 3D visualization controls
        st.subheader("3D Frequency Components Visualization")
        downsample = st.slider(
            "Downsampling factor for 3D plots:",
            min_value=1,
            max_value=20,
            value=5,
            help=(
                "Controls the resolution of the 3D surface plots. Higher values "
                "improve performance but reduce the plot's detail."
            ),
        )

        # Generate and display 3D plots
        fig = create_3d_plot(filtered_fft, downsample)
        st.plotly_chart(fig, use_container_width=True)