Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,10 +4,30 @@ import cv2
|
|
4 |
import plotly.graph_objects as go
|
5 |
from plotly.subplots import make_subplots
|
6 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# FFT processing functions
|
9 |
def apply_fft(image):
|
10 |
-
"""Apply FFT to each channel of the image and return shifted FFT channels."""
|
11 |
fft_channels = []
|
12 |
for channel in cv2.split(image):
|
13 |
fft = np.fft.fft2(channel)
|
@@ -16,7 +36,6 @@ def apply_fft(image):
|
|
16 |
return fft_channels
|
17 |
|
18 |
def filter_fft_percentage(fft_channels, percentage):
|
19 |
-
"""Filter FFT channels to keep top percentage of magnitudes."""
|
20 |
filtered_fft = []
|
21 |
for fft_data in fft_channels:
|
22 |
magnitude = np.abs(fft_data)
|
@@ -28,7 +47,6 @@ def filter_fft_percentage(fft_channels, percentage):
|
|
28 |
return filtered_fft
|
29 |
|
30 |
def inverse_fft(filtered_fft):
|
31 |
-
"""Reconstruct image from filtered FFT channels."""
|
32 |
reconstructed_channels = []
|
33 |
for fft_data in filtered_fft:
|
34 |
fft_ishift = np.fft.ifftshift(fft_data)
|
@@ -37,8 +55,22 @@ def inverse_fft(filtered_fft):
|
|
37 |
reconstructed_channels.append(img_normalized.astype(np.uint8))
|
38 |
return cv2.merge(reconstructed_channels)
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def create_3d_plot(fft_channels, downsample_factor=1):
|
41 |
-
"""Create interactive 3D surface plots using Plotly."""
|
42 |
fig = make_subplots(
|
43 |
rows=3, cols=2,
|
44 |
specs=[[{'type': 'scene'}, {'type': 'scene'}],
|
@@ -51,33 +83,26 @@ def create_3d_plot(fft_channels, downsample_factor=1):
|
|
51 |
)
|
52 |
)
|
53 |
|
54 |
-
channel_names = ['Blue', 'Green', 'Red']
|
55 |
-
|
56 |
for i, fft_data in enumerate(fft_channels):
|
57 |
-
# Downsample data for better performance
|
58 |
fft_down = fft_data[::downsample_factor, ::downsample_factor]
|
59 |
magnitude = np.abs(fft_down)
|
60 |
phase = np.angle(fft_down)
|
61 |
|
62 |
-
# Create grid coordinates
|
63 |
rows, cols = magnitude.shape
|
64 |
x = np.linspace(-cols//2, cols//2, cols)
|
65 |
y = np.linspace(-rows//2, rows//2, rows)
|
66 |
X, Y = np.meshgrid(x, y)
|
67 |
|
68 |
-
# Magnitude plot
|
69 |
fig.add_trace(
|
70 |
go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
|
71 |
row=i+1, col=1
|
72 |
)
|
73 |
|
74 |
-
# Phase plot
|
75 |
fig.add_trace(
|
76 |
go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
|
77 |
row=i+1, col=2
|
78 |
)
|
79 |
|
80 |
-
# Update layout for better visualization
|
81 |
fig.update_layout(
|
82 |
height=1500,
|
83 |
width=1200,
|
@@ -93,80 +118,104 @@ def create_3d_plot(fft_channels, downsample_factor=1):
|
|
93 |
|
94 |
# Streamlit UI
|
95 |
st.set_page_config(layout="wide")
|
96 |
-
st.title("Interactive Frequency Domain Analysis")
|
97 |
-
|
98 |
-
#
|
99 |
-
|
100 |
-
st.
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
106 |
uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
|
107 |
|
108 |
if uploaded_file is not None:
|
109 |
-
# Read and display original image
|
110 |
file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
|
111 |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
112 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
113 |
st.image(image_rgb, caption="Original Image", use_column_width=True)
|
114 |
|
115 |
-
#
|
116 |
-
if
|
117 |
st.session_state.fft_channels = apply_fft(image)
|
118 |
|
119 |
-
#
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
)
|
126 |
-
submit_button = st.form_submit_button(label="Apply Filter")
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
|
131 |
-
reconstructed = inverse_fft(filtered_fft)
|
132 |
-
|
|
|
|
|
|
|
|
|
133 |
st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
|
134 |
|
135 |
-
#
|
136 |
st.subheader("Frequency Data of Each Channel")
|
137 |
-
fft_data_dict = {}
|
138 |
for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
|
139 |
-
magnitude = np.abs(st.session_state.fft_channels[i])
|
140 |
-
phase = np.angle(st.session_state.fft_channels[i])
|
141 |
-
fft_data_dict[channel_name] = {'Magnitude': magnitude, 'Phase': phase}
|
142 |
-
|
143 |
-
# Create DataFrames for each channel's FFT data
|
144 |
-
for channel_name, data in fft_data_dict.items():
|
145 |
st.write(f"### {channel_name} Channel FFT Data")
|
146 |
-
magnitude_df = pd.DataFrame(
|
147 |
-
phase_df = pd.DataFrame(
|
148 |
st.write("#### Magnitude Data:")
|
149 |
-
st.dataframe(magnitude_df.head(10))
|
150 |
st.write("#### Phase Data:")
|
151 |
-
st.dataframe(phase_df.head(10))
|
152 |
-
|
153 |
-
# Download button for reconstructed image
|
154 |
-
_, encoded_img = cv2.imencode('.png', reconstructed)
|
155 |
-
st.download_button(
|
156 |
-
"Download Reconstructed Image",
|
157 |
-
encoded_img.tobytes(),
|
158 |
-
"reconstructed.png",
|
159 |
-
"image/png"
|
160 |
-
)
|
161 |
|
162 |
-
# 3D
|
163 |
st.subheader("3D Frequency Components Visualization")
|
164 |
downsample = st.slider(
|
165 |
"Downsampling factor for 3D plots:",
|
166 |
-
|
167 |
-
help="Controls the resolution of the 3D surface plots.
|
168 |
)
|
169 |
-
|
170 |
-
# Generate and display 3D plots
|
171 |
-
fig = create_3d_plot(filtered_fft, downsample)
|
172 |
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import plotly.graph_objects as go
|
5 |
from plotly.subplots import make_subplots
|
6 |
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
# Dummy CNN Model
|
12 |
+
class SimpleCNN(nn.Module):
|
13 |
+
def __init__(self):
|
14 |
+
super(SimpleCNN, self).__init__()
|
15 |
+
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
|
16 |
+
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
|
17 |
+
self.fc1 = nn.Linear(32 * 8 * 8, 128)
|
18 |
+
self.fc2 = nn.Linear(128, 10)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
x1 = F.relu(self.conv1(x)) # First conv layer activation
|
22 |
+
x2 = F.relu(self.conv2(x1))
|
23 |
+
x3 = F.adaptive_avg_pool2d(x2, (8, 8))
|
24 |
+
x4 = x3.view(x3.size(0), -1)
|
25 |
+
x5 = F.relu(self.fc1(x4))
|
26 |
+
x6 = self.fc2(x5)
|
27 |
+
return x6, x1 # Return both output and first layer activations
|
28 |
|
29 |
# FFT processing functions
|
30 |
def apply_fft(image):
|
|
|
31 |
fft_channels = []
|
32 |
for channel in cv2.split(image):
|
33 |
fft = np.fft.fft2(channel)
|
|
|
36 |
return fft_channels
|
37 |
|
38 |
def filter_fft_percentage(fft_channels, percentage):
|
|
|
39 |
filtered_fft = []
|
40 |
for fft_data in fft_channels:
|
41 |
magnitude = np.abs(fft_data)
|
|
|
47 |
return filtered_fft
|
48 |
|
49 |
def inverse_fft(filtered_fft):
|
|
|
50 |
reconstructed_channels = []
|
51 |
for fft_data in filtered_fft:
|
52 |
fft_ishift = np.fft.ifftshift(fft_data)
|
|
|
55 |
reconstructed_channels.append(img_normalized.astype(np.uint8))
|
56 |
return cv2.merge(reconstructed_channels)
|
57 |
|
58 |
+
# CNN Pass Visualization
|
59 |
+
def pass_to_cnn(fft_image):
|
60 |
+
model = SimpleCNN()
|
61 |
+
magnitude_tensor = torch.tensor(np.abs(fft_image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
62 |
+
|
63 |
+
with torch.no_grad():
|
64 |
+
output, activations = model(magnitude_tensor)
|
65 |
+
|
66 |
+
# Ensure activations have the correct shape [batch_size, channels, height, width]
|
67 |
+
if len(activations.shape) == 3:
|
68 |
+
activations = activations.unsqueeze(0) # Add batch dimension if missing
|
69 |
+
|
70 |
+
return activations, magnitude_tensor
|
71 |
+
|
72 |
+
# 3D plotting function
|
73 |
def create_3d_plot(fft_channels, downsample_factor=1):
|
|
|
74 |
fig = make_subplots(
|
75 |
rows=3, cols=2,
|
76 |
specs=[[{'type': 'scene'}, {'type': 'scene'}],
|
|
|
83 |
)
|
84 |
)
|
85 |
|
|
|
|
|
86 |
for i, fft_data in enumerate(fft_channels):
|
|
|
87 |
fft_down = fft_data[::downsample_factor, ::downsample_factor]
|
88 |
magnitude = np.abs(fft_down)
|
89 |
phase = np.angle(fft_down)
|
90 |
|
|
|
91 |
rows, cols = magnitude.shape
|
92 |
x = np.linspace(-cols//2, cols//2, cols)
|
93 |
y = np.linspace(-rows//2, rows//2, rows)
|
94 |
X, Y = np.meshgrid(x, y)
|
95 |
|
|
|
96 |
fig.add_trace(
|
97 |
go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
|
98 |
row=i+1, col=1
|
99 |
)
|
100 |
|
|
|
101 |
fig.add_trace(
|
102 |
go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
|
103 |
row=i+1, col=2
|
104 |
)
|
105 |
|
|
|
106 |
fig.update_layout(
|
107 |
height=1500,
|
108 |
width=1200,
|
|
|
118 |
|
119 |
# Streamlit UI
|
120 |
st.set_page_config(layout="wide")
|
121 |
+
st.title("Interactive Frequency Domain Analysis with CNN")
|
122 |
+
|
123 |
+
# Initialize session state
|
124 |
+
if 'fft_channels' not in st.session_state:
|
125 |
+
st.session_state.fft_channels = None
|
126 |
+
if 'filtered_fft' not in st.session_state:
|
127 |
+
st.session_state.filtered_fft = None
|
128 |
+
if 'reconstructed' not in st.session_state:
|
129 |
+
st.session_state.reconstructed = None
|
130 |
+
if 'show_cnn' not in st.session_state:
|
131 |
+
st.session_state.show_cnn = False
|
132 |
+
|
133 |
+
# Upload image
|
134 |
uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
|
135 |
|
136 |
if uploaded_file is not None:
|
|
|
137 |
file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
|
138 |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
139 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
140 |
st.image(image_rgb, caption="Original Image", use_column_width=True)
|
141 |
|
142 |
+
# Apply FFT and store in session state
|
143 |
+
if st.session_state.fft_channels is None:
|
144 |
st.session_state.fft_channels = apply_fft(image)
|
145 |
|
146 |
+
# Frequency percentage slider
|
147 |
+
percentage = st.slider(
|
148 |
+
"Percentage of frequencies to retain:",
|
149 |
+
0.1, 100.0, 10.0, 0.1,
|
150 |
+
help="Adjust the slider to select what portion of frequency components to keep."
|
151 |
+
)
|
|
|
|
|
152 |
|
153 |
+
# Apply FFT filter
|
154 |
+
if st.button("Apply Filter"):
|
155 |
+
st.session_state.filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
|
156 |
+
st.session_state.reconstructed = inverse_fft(st.session_state.filtered_fft)
|
157 |
+
st.session_state.show_cnn = False # Reset CNN visualization
|
158 |
+
|
159 |
+
# Display reconstructed image and FFT data
|
160 |
+
if st.session_state.reconstructed is not None:
|
161 |
+
reconstructed_rgb = cv2.cvtColor(st.session_state.reconstructed, cv2.COLOR_BGR2RGB)
|
162 |
st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
|
163 |
|
164 |
+
# FFT Data Tables
|
165 |
st.subheader("Frequency Data of Each Channel")
|
|
|
166 |
for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
st.write(f"### {channel_name} Channel FFT Data")
|
168 |
+
magnitude_df = pd.DataFrame(np.abs(st.session_state.filtered_fft[i]))
|
169 |
+
phase_df = pd.DataFrame(np.angle(st.session_state.filtered_fft[i]))
|
170 |
st.write("#### Magnitude Data:")
|
171 |
+
st.dataframe(magnitude_df.head(10))
|
172 |
st.write("#### Phase Data:")
|
173 |
+
st.dataframe(phase_df.head(10))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
# 3D Visualization
|
176 |
st.subheader("3D Frequency Components Visualization")
|
177 |
downsample = st.slider(
|
178 |
"Downsampling factor for 3D plots:",
|
179 |
+
1, 20, 5,
|
180 |
+
help="Controls the resolution of the 3D surface plots."
|
181 |
)
|
182 |
+
fig = create_3d_plot(st.session_state.filtered_fft, downsample)
|
|
|
|
|
183 |
st.plotly_chart(fig, use_container_width=True)
|
184 |
+
|
185 |
+
# CNN Visualization Section
|
186 |
+
if st.button("Pass to CNN"):
|
187 |
+
st.session_state.show_cnn = True
|
188 |
+
|
189 |
+
if st.session_state.show_cnn:
|
190 |
+
st.subheader("CNN Processing Visualization")
|
191 |
+
activations, magnitude_tensor = pass_to_cnn(st.session_state.filtered_fft[0])
|
192 |
+
|
193 |
+
# Display input tensor
|
194 |
+
st.write("### Input Magnitude Tensor:")
|
195 |
+
st.image(magnitude_tensor.squeeze().numpy(),
|
196 |
+
caption="Magnitude Tensor",
|
197 |
+
use_column_width=True,
|
198 |
+
clamp=True)
|
199 |
+
|
200 |
+
# Display activations
|
201 |
+
st.write("### First Convolution Layer Activations")
|
202 |
+
activation = activations.detach().numpy()
|
203 |
+
|
204 |
+
# Check the shape of the activation tensor
|
205 |
+
if len(activation.shape) == 4: # [batch_size, channels, height, width]
|
206 |
+
for i in range(activation.shape[1]): # Loop through channels
|
207 |
+
act_img = activation[0, i, :, :] # Select the first batch and current channel
|
208 |
+
act_img_normalized = (act_img - act_img.min()) / (act_img.max() - act_img.min()) # Normalize
|
209 |
+
|
210 |
+
# Display activation map
|
211 |
+
st.write(f"#### Activation Channel {i+1}")
|
212 |
+
st.image(act_img_normalized,
|
213 |
+
caption=f"Activation Channel {i+1}",
|
214 |
+
use_column_width=True)
|
215 |
+
|
216 |
+
# Display activation values in a table
|
217 |
+
st.write("##### Activation Values:")
|
218 |
+
activation_df = pd.DataFrame(act_img)
|
219 |
+
st.dataframe(activation_df)
|
220 |
+
else:
|
221 |
+
st.error(f"Unexpected activation shape: {activation.shape}. Expected 4 dimensions.")
|