File size: 6,879 Bytes
1219288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import hashlib

import cv2
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

# FFT processing functions
def apply_fft(image):
    """Return the centred 2-D spectrum of every channel of *image*.

    The image is split into channels with ``cv2.split``; each channel is
    transformed with ``np.fft.fft2`` and then passed through
    ``np.fft.fftshift`` so the zero-frequency term sits in the middle of the
    array.

    Returns:
        list of complex ndarrays, one per channel, in ``cv2.split`` order.
    """
    return [np.fft.fftshift(np.fft.fft2(plane)) for plane in cv2.split(image)]

def filter_fft_percentage(fft_channels, percentage):
    """Zero all but the strongest ``percentage`` % of coefficients per channel.

    For each channel, ``k = int(size * percentage / 100)`` (clamped to
    ``[0, size]``) and the threshold is the k-th largest magnitude.
    Coefficients whose magnitude is >= that threshold are kept; all others are
    set to zero. Every coefficient tied with the threshold is kept, so slightly
    more than ``k`` entries can survive.

    Args:
        fft_channels: iterable of complex arrays (e.g. from ``apply_fft``).
        percentage: share of coefficients to keep, in percent. If it rounds
            down to zero coefficients, nothing is kept; values >= 100 keep
            everything.

    Returns:
        list of arrays with the same shapes and dtypes as the inputs.
    """
    filtered_fft = []
    for fft_data in fft_channels:
        magnitude = np.abs(fft_data)
        flat = magnitude.ravel()
        size = flat.size
        num_keep = min(max(int(size * percentage / 100), 0), size)
        if num_keep == 0:
            # Keeping zero coefficients must zero the spectrum. The old
            # "threshold = 0" fallback kept *everything* instead.
            mask = np.zeros(magnitude.shape, dtype=bool)
        else:
            # k-th largest magnitude via O(n) selection rather than a full sort:
            # the k-th largest is the (size - k)-th smallest (0-based).
            threshold = np.partition(flat, size - num_keep)[size - num_keep]
            mask = magnitude >= threshold
        filtered_fft.append(fft_data * mask)
    return filtered_fft

def inverse_fft(filtered_fft):
    """Rebuild an 8-bit image from centred per-channel spectra.

    Each spectrum is un-shifted with ``np.fft.ifftshift`` and inverted with
    ``np.fft.ifft2``. The real part is then rounded and clipped to [0, 255].

    Clipping is used instead of per-channel min-max stretching. Stretching each
    channel independently shifted colours, and it stopped a 100 % reconstruction
    from matching the original. It also mapped any constant channel (e.g. an
    all-white image) to 0.

    Args:
        filtered_fft: non-empty sequence of centred complex 2-D spectra, e.g.
            from ``filter_fft_percentage``.

    Returns:
        ``uint8`` array of shape ``(H, W, C)``, or ``(H, W)`` when exactly one
        channel is given (the same shapes ``cv2.merge`` produced).
    """
    reconstructed_channels = []
    for fft_data in filtered_fft:
        spatial = np.fft.ifft2(np.fft.ifftshift(fft_data)).real
        # Round before casting: astype alone truncates, so 9.9999999 became 9.
        reconstructed_channels.append(np.clip(np.rint(spatial), 0, 255).astype(np.uint8))
    if len(reconstructed_channels) == 1:
        return reconstructed_channels[0]
    return np.stack(reconstructed_channels, axis=-1)

def create_3d_plot(fft_channels, downsample_factor=1):
    """Create interactive 3D surface plots of each channel's spectrum.

    The figure has one row per channel: magnitude (``np.abs``) on the left and
    phase (``np.angle``, radians) on the right. x/y are the signed frequency
    indices of the centred (``fftshift``-ed) spectrum, so the zero-frequency
    term sits at (0, 0) whatever ``downsample_factor`` is.

    Args:
        fft_channels: non-empty sequence of centred complex 2-D spectra, e.g.
            from ``apply_fft`` / ``filter_fft_percentage``. Exactly three
            channels are titled Blue/Green/Red. Any other count is titled
            "Channel 1", "Channel 2", ...
        downsample_factor: plot every n-th row/column for responsiveness.
            Must be >= 1.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if ``fft_channels`` is empty or ``downsample_factor`` < 1.
    """
    fft_channels = list(fft_channels)
    if not fft_channels:
        raise ValueError("fft_channels must contain at least one channel")
    step = int(downsample_factor)
    if step < 1:
        raise ValueError("downsample_factor must be a positive integer")

    n_rows = len(fft_channels)
    if n_rows == 3:
        channel_names = ['Blue', 'Green', 'Red']
    else:
        channel_names = [f'Channel {i + 1}' for i in range(n_rows)]

    subplot_titles = []
    for name in channel_names:
        subplot_titles.extend((f'{name} - Magnitude', f'{name} - Phase'))

    fig = make_subplots(
        rows=n_rows, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}] for _ in range(n_rows)],
        subplot_titles=subplot_titles,
    )

    for row, fft_data in enumerate(fft_channels, start=1):
        full_rows, full_cols = fft_data.shape
        # After fftshift, index k holds frequency k - N//2. Slice the axes the
        # same way as the data so downsampled points keep their true frequency.
        fx = (np.arange(full_cols) - full_cols // 2)[::step]
        fy = (np.arange(full_rows) - full_rows // 2)[::step]
        X, Y = np.meshgrid(fx, fy)

        # Downsample data for better performance
        fft_down = fft_data[::step, ::step]

        # Magnitude plot
        fig.add_trace(
            go.Surface(x=X, y=Y, z=np.abs(fft_down), colorscale='Viridis', showscale=False),
            row=row, col=1,
        )
        # Phase plot
        fig.add_trace(
            go.Surface(x=X, y=Y, z=np.angle(fft_down), colorscale='Inferno', showscale=False),
            row=row, col=2,
        )

    fig.update_layout(
        height=500 * n_rows,
        width=1200,
        margin=dict(l=0, r=0, b=0, t=30),
    )
    # update_scenes applies to every scene; update_layout(scene=...) only
    # reached the first subplot, leaving the other five unlabelled.
    fig.update_scenes(
        xaxis_title='Frequency X',
        yaxis_title='Frequency Y',
        camera=dict(eye=dict(x=1.5, y=1.5, z=0.5)),
    )
    fig.update_scenes(zaxis_title='Magnitude', col=1)
    fig.update_scenes(zaxis_title='Phase (rad)', col=2)
    return fig

# Streamlit UI
# Introductory copy shown above the uploader (rendered as markdown by st.write).
_INTRO_TEXT = """Fast Fourier Transform (FFT) is a technique to transform an image from the spatial domain to the frequency domain.
    In this domain, each frequency represents a different aspect of the image's texture and details.
    By filtering out certain frequencies, you can modify the image's appearance, enhancing or suppressing certain features."""

# set_page_config is still the first Streamlit call; the constant above has no UI effect.
st.set_page_config(layout="wide")
st.title("Interactive Frequency Domain Analysis")
st.subheader("Introduction to FFT and Image Filtering")
st.write(_INTRO_TEXT)

uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])

if uploaded_file is not None:
    # Read and display original image (OpenCV decodes in BGR channel order).
    raw_bytes = uploaded_file.getvalue()
    image = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        # imdecode returns None for undecodable data. Without this check the
        # failure would surface as an opaque cvtColor error.
        st.error("Could not decode the uploaded file as an image.")
        st.stop()
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Original Image", use_column_width=True)

    # Cache the FFT per image content. The old "'fft_channels' not in
    # session_state" check kept the FIRST upload's spectrum forever, so later
    # uploads were filtered with the wrong image's frequencies.
    image_key = hashlib.sha256(raw_bytes).hexdigest()
    if st.session_state.get('fft_image_key') != image_key:
        st.session_state.fft_channels = apply_fft(image)
        st.session_state.fft_image_key = image_key
        # A newly uploaded image must be filtered explicitly before results show.
        st.session_state.applied_percentage = None

    # Create a form to submit frequency percentage selection
    with st.form(key='fft_form'):
        percentage = st.slider(
            "Percentage of frequencies to retain:",
            min_value=0.1, max_value=100.0, value=10.0, step=0.1,
            help="Adjust the slider to select what portion of frequency components to keep. Lower values blur the image."
        )
        submit_button = st.form_submit_button(label="Apply Filter")

    # submit_button is True only on the rerun triggered by the button. Persist
    # the submitted value so that later reruns keep the results on screen.
    # Those reruns come from the downsample slider or the download button.
    if submit_button:
        st.session_state.applied_percentage = percentage

    applied_percentage = st.session_state.get('applied_percentage')
    if applied_percentage is not None:
        # Apply filtering and reconstruct image
        filtered_fft = filter_fft_percentage(st.session_state.fft_channels, applied_percentage)
        reconstructed = inverse_fft(filtered_fft)
        reconstructed_rgb = cv2.cvtColor(reconstructed, cv2.COLOR_BGR2RGB)
        st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)

        # Display FFT Data in Table Format (unfiltered spectrum, first 10 rows)
        st.subheader("Frequency Data of Each Channel")
        for channel_name, fft_data in zip(['Blue', 'Green', 'Red'], st.session_state.fft_channels):
            # Only 10 rows are shown, so only those rows are converted, not the full image.
            preview = fft_data[:10]
            st.write(f"### {channel_name} Channel FFT Data")
            st.write("#### Magnitude Data:")
            st.dataframe(pd.DataFrame(np.abs(preview)))
            st.write("#### Phase Data:")
            st.dataframe(pd.DataFrame(np.angle(preview)))

        # Download button for reconstructed image
        encoded_ok, encoded_img = cv2.imencode('.png', reconstructed)
        if encoded_ok:
            st.download_button(
                "Download Reconstructed Image",
                encoded_img.tobytes(),
                "reconstructed.png",
                "image/png"
            )
        else:
            st.warning("Could not encode the reconstructed image for download.")

        # 3D visualization controls
        st.subheader("3D Frequency Components Visualization")
        downsample = st.slider(
            "Downsampling factor for 3D plots:",
            min_value=1, max_value=20, value=5,
            help="Controls the resolution of the 3D surface plots. Higher values improve performance but reduce the plot's detail."
        )

        # Generate and display 3D plots
        fig = create_3d_plot(filtered_fft, downsample)
        st.plotly_chart(fig, use_container_width=True)