Mattral commited on
Commit
4601fa2
·
verified ·
1 Parent(s): b48d0fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -64
app.py CHANGED
@@ -4,10 +4,30 @@ import cv2
4
  import plotly.graph_objects as go
5
  from plotly.subplots import make_subplots
6
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # FFT processing functions
9
  def apply_fft(image):
10
- """Apply FFT to each channel of the image and return shifted FFT channels."""
11
  fft_channels = []
12
  for channel in cv2.split(image):
13
  fft = np.fft.fft2(channel)
@@ -16,7 +36,6 @@ def apply_fft(image):
16
  return fft_channels
17
 
18
  def filter_fft_percentage(fft_channels, percentage):
19
- """Filter FFT channels to keep top percentage of magnitudes."""
20
  filtered_fft = []
21
  for fft_data in fft_channels:
22
  magnitude = np.abs(fft_data)
@@ -28,7 +47,6 @@ def filter_fft_percentage(fft_channels, percentage):
28
  return filtered_fft
29
 
30
  def inverse_fft(filtered_fft):
31
- """Reconstruct image from filtered FFT channels."""
32
  reconstructed_channels = []
33
  for fft_data in filtered_fft:
34
  fft_ishift = np.fft.ifftshift(fft_data)
@@ -37,8 +55,22 @@ def inverse_fft(filtered_fft):
37
  reconstructed_channels.append(img_normalized.astype(np.uint8))
38
  return cv2.merge(reconstructed_channels)
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def create_3d_plot(fft_channels, downsample_factor=1):
41
- """Create interactive 3D surface plots using Plotly."""
42
  fig = make_subplots(
43
  rows=3, cols=2,
44
  specs=[[{'type': 'scene'}, {'type': 'scene'}],
@@ -51,33 +83,26 @@ def create_3d_plot(fft_channels, downsample_factor=1):
51
  )
52
  )
53
 
54
- channel_names = ['Blue', 'Green', 'Red']
55
-
56
  for i, fft_data in enumerate(fft_channels):
57
- # Downsample data for better performance
58
  fft_down = fft_data[::downsample_factor, ::downsample_factor]
59
  magnitude = np.abs(fft_down)
60
  phase = np.angle(fft_down)
61
 
62
- # Create grid coordinates
63
  rows, cols = magnitude.shape
64
  x = np.linspace(-cols//2, cols//2, cols)
65
  y = np.linspace(-rows//2, rows//2, rows)
66
  X, Y = np.meshgrid(x, y)
67
 
68
- # Magnitude plot
69
  fig.add_trace(
70
  go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
71
  row=i+1, col=1
72
  )
73
 
74
- # Phase plot
75
  fig.add_trace(
76
  go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
77
  row=i+1, col=2
78
  )
79
 
80
- # Update layout for better visualization
81
  fig.update_layout(
82
  height=1500,
83
  width=1200,
@@ -93,80 +118,104 @@ def create_3d_plot(fft_channels, downsample_factor=1):
93
 
94
  # Streamlit UI
95
  st.set_page_config(layout="wide")
96
- st.title("Interactive Frequency Domain Analysis")
97
-
98
- # Introduction Text
99
- st.subheader("Introduction to FFT and Image Filtering")
100
- st.write(
101
- """Fast Fourier Transform (FFT) is a technique to transform an image from the spatial domain to the frequency domain.
102
- In this domain, each frequency represents a different aspect of the image's texture and details.
103
- By filtering out certain frequencies, you can modify the image's appearance, enhancing or suppressing certain features."""
104
- )
105
-
 
 
 
106
  uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
107
 
108
  if uploaded_file is not None:
109
- # Read and display original image
110
  file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
111
  image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
112
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
113
  st.image(image_rgb, caption="Original Image", use_column_width=True)
114
 
115
- # Process FFT and store in session state
116
- if 'fft_channels' not in st.session_state:
117
  st.session_state.fft_channels = apply_fft(image)
118
 
119
- # Create a form to submit frequency percentage selection
120
- with st.form(key='fft_form'):
121
- percentage = st.slider(
122
- "Percentage of frequencies to retain:",
123
- min_value=0.1, max_value=100.0, value=10.0, step=0.1,
124
- help="Adjust the slider to select what portion of frequency components to keep. Lower values blur the image."
125
- )
126
- submit_button = st.form_submit_button(label="Apply Filter")
127
 
128
- if submit_button:
129
- # Apply filtering and reconstruct image
130
- filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
131
- reconstructed = inverse_fft(filtered_fft)
132
- reconstructed_rgb = cv2.cvtColor(reconstructed, cv2.COLOR_BGR2RGB)
 
 
 
 
133
  st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
134
 
135
- # Display FFT Data in Table Format
136
  st.subheader("Frequency Data of Each Channel")
137
- fft_data_dict = {}
138
  for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
139
- magnitude = np.abs(st.session_state.fft_channels[i])
140
- phase = np.angle(st.session_state.fft_channels[i])
141
- fft_data_dict[channel_name] = {'Magnitude': magnitude, 'Phase': phase}
142
-
143
- # Create DataFrames for each channel's FFT data
144
- for channel_name, data in fft_data_dict.items():
145
  st.write(f"### {channel_name} Channel FFT Data")
146
- magnitude_df = pd.DataFrame(data['Magnitude'])
147
- phase_df = pd.DataFrame(data['Phase'])
148
  st.write("#### Magnitude Data:")
149
- st.dataframe(magnitude_df.head(10)) # Display first 10 rows for brevity
150
  st.write("#### Phase Data:")
151
- st.dataframe(phase_df.head(10)) # Display first 10 rows for brevity
152
-
153
- # Download button for reconstructed image
154
- _, encoded_img = cv2.imencode('.png', reconstructed)
155
- st.download_button(
156
- "Download Reconstructed Image",
157
- encoded_img.tobytes(),
158
- "reconstructed.png",
159
- "image/png"
160
- )
161
 
162
- # 3D visualization controls
163
  st.subheader("3D Frequency Components Visualization")
164
  downsample = st.slider(
165
  "Downsampling factor for 3D plots:",
166
- min_value=1, max_value=20, value=5,
167
- help="Controls the resolution of the 3D surface plots. Higher values improve performance but reduce the plot's detail."
168
  )
169
-
170
- # Generate and display 3D plots
171
- fig = create_3d_plot(filtered_fft, downsample)
172
  st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import plotly.graph_objects as go
5
  from plotly.subplots import make_subplots
6
  import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # Dummy CNN Model
12
+ class SimpleCNN(nn.Module):
13
+ def __init__(self):
14
+ super(SimpleCNN, self).__init__()
15
+ self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
16
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
17
+ self.fc1 = nn.Linear(32 * 8 * 8, 128)
18
+ self.fc2 = nn.Linear(128, 10)
19
+
20
+ def forward(self, x):
21
+ x1 = F.relu(self.conv1(x)) # First conv layer activation
22
+ x2 = F.relu(self.conv2(x1))
23
+ x3 = F.adaptive_avg_pool2d(x2, (8, 8))
24
+ x4 = x3.view(x3.size(0), -1)
25
+ x5 = F.relu(self.fc1(x4))
26
+ x6 = self.fc2(x5)
27
+ return x6, x1 # Return both output and first layer activations
28
 
29
  # FFT processing functions
30
  def apply_fft(image):
 
31
  fft_channels = []
32
  for channel in cv2.split(image):
33
  fft = np.fft.fft2(channel)
 
36
  return fft_channels
37
 
38
  def filter_fft_percentage(fft_channels, percentage):
 
39
  filtered_fft = []
40
  for fft_data in fft_channels:
41
  magnitude = np.abs(fft_data)
 
47
  return filtered_fft
48
 
49
  def inverse_fft(filtered_fft):
 
50
  reconstructed_channels = []
51
  for fft_data in filtered_fft:
52
  fft_ishift = np.fft.ifftshift(fft_data)
 
55
  reconstructed_channels.append(img_normalized.astype(np.uint8))
56
  return cv2.merge(reconstructed_channels)
57
 
58
+ # CNN Pass Visualization
59
+ def pass_to_cnn(fft_image):
60
+ model = SimpleCNN()
61
+ magnitude_tensor = torch.tensor(np.abs(fft_image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
62
+
63
+ with torch.no_grad():
64
+ output, activations = model(magnitude_tensor)
65
+
66
+ # Ensure activations have the correct shape [batch_size, channels, height, width]
67
+ if len(activations.shape) == 3:
68
+ activations = activations.unsqueeze(0) # Add batch dimension if missing
69
+
70
+ return activations, magnitude_tensor
71
+
72
+ # 3D plotting function
73
  def create_3d_plot(fft_channels, downsample_factor=1):
 
74
  fig = make_subplots(
75
  rows=3, cols=2,
76
  specs=[[{'type': 'scene'}, {'type': 'scene'}],
 
83
  )
84
  )
85
 
 
 
86
  for i, fft_data in enumerate(fft_channels):
 
87
  fft_down = fft_data[::downsample_factor, ::downsample_factor]
88
  magnitude = np.abs(fft_down)
89
  phase = np.angle(fft_down)
90
 
 
91
  rows, cols = magnitude.shape
92
  x = np.linspace(-cols//2, cols//2, cols)
93
  y = np.linspace(-rows//2, rows//2, rows)
94
  X, Y = np.meshgrid(x, y)
95
 
 
96
  fig.add_trace(
97
  go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
98
  row=i+1, col=1
99
  )
100
 
 
101
  fig.add_trace(
102
  go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
103
  row=i+1, col=2
104
  )
105
 
 
106
  fig.update_layout(
107
  height=1500,
108
  width=1200,
 
118
 
119
  # Streamlit UI
120
  st.set_page_config(layout="wide")
121
+ st.title("Interactive Frequency Domain Analysis with CNN")
122
+
123
+ # Initialize session state
124
+ if 'fft_channels' not in st.session_state:
125
+ st.session_state.fft_channels = None
126
+ if 'filtered_fft' not in st.session_state:
127
+ st.session_state.filtered_fft = None
128
+ if 'reconstructed' not in st.session_state:
129
+ st.session_state.reconstructed = None
130
+ if 'show_cnn' not in st.session_state:
131
+ st.session_state.show_cnn = False
132
+
133
+ # Upload image
134
  uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
135
 
136
  if uploaded_file is not None:
 
137
  file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
138
  image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
139
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
140
  st.image(image_rgb, caption="Original Image", use_column_width=True)
141
 
142
+ # Apply FFT and store in session state
143
+ if st.session_state.fft_channels is None:
144
  st.session_state.fft_channels = apply_fft(image)
145
 
146
+ # Frequency percentage slider
147
+ percentage = st.slider(
148
+ "Percentage of frequencies to retain:",
149
+ 0.1, 100.0, 10.0, 0.1,
150
+ help="Adjust the slider to select what portion of frequency components to keep."
151
+ )
 
 
152
 
153
+ # Apply FFT filter
154
+ if st.button("Apply Filter"):
155
+ st.session_state.filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
156
+ st.session_state.reconstructed = inverse_fft(st.session_state.filtered_fft)
157
+ st.session_state.show_cnn = False # Reset CNN visualization
158
+
159
+ # Display reconstructed image and FFT data
160
+ if st.session_state.reconstructed is not None:
161
+ reconstructed_rgb = cv2.cvtColor(st.session_state.reconstructed, cv2.COLOR_BGR2RGB)
162
  st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
163
 
164
+ # FFT Data Tables
165
  st.subheader("Frequency Data of Each Channel")
 
166
  for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
 
 
 
 
 
 
167
  st.write(f"### {channel_name} Channel FFT Data")
168
+ magnitude_df = pd.DataFrame(np.abs(st.session_state.filtered_fft[i]))
169
+ phase_df = pd.DataFrame(np.angle(st.session_state.filtered_fft[i]))
170
  st.write("#### Magnitude Data:")
171
+ st.dataframe(magnitude_df.head(10))
172
  st.write("#### Phase Data:")
173
+ st.dataframe(phase_df.head(10))
 
 
 
 
 
 
 
 
 
174
 
175
+ # 3D Visualization
176
  st.subheader("3D Frequency Components Visualization")
177
  downsample = st.slider(
178
  "Downsampling factor for 3D plots:",
179
+ 1, 20, 5,
180
+ help="Controls the resolution of the 3D surface plots."
181
  )
182
+ fig = create_3d_plot(st.session_state.filtered_fft, downsample)
 
 
183
  st.plotly_chart(fig, use_container_width=True)
184
+
185
+ # CNN Visualization Section
186
+ if st.button("Pass to CNN"):
187
+ st.session_state.show_cnn = True
188
+
189
+ if st.session_state.show_cnn:
190
+ st.subheader("CNN Processing Visualization")
191
+ activations, magnitude_tensor = pass_to_cnn(st.session_state.filtered_fft[0])
192
+
193
+ # Display input tensor
194
+ st.write("### Input Magnitude Tensor:")
195
+ st.image(magnitude_tensor.squeeze().numpy(),
196
+ caption="Magnitude Tensor",
197
+ use_column_width=True,
198
+ clamp=True)
199
+
200
+ # Display activations
201
+ st.write("### First Convolution Layer Activations")
202
+ activation = activations.detach().numpy()
203
+
204
+ # Check the shape of the activation tensor
205
+ if len(activation.shape) == 4: # [batch_size, channels, height, width]
206
+ for i in range(activation.shape[1]): # Loop through channels
207
+ act_img = activation[0, i, :, :] # Select the first batch and current channel
208
+ act_img_normalized = (act_img - act_img.min()) / (act_img.max() - act_img.min()) # Normalize
209
+
210
+ # Display activation map
211
+ st.write(f"#### Activation Channel {i+1}")
212
+ st.image(act_img_normalized,
213
+ caption=f"Activation Channel {i+1}",
214
+ use_column_width=True)
215
+
216
+ # Display activation values in a table
217
+ st.write("##### Activation Values:")
218
+ activation_df = pd.DataFrame(act_img)
219
+ st.dataframe(activation_df)
220
+ else:
221
+ st.error(f"Unexpected activation shape: {activation.shape}. Expected 4 dimensions.")