Mattral commited on
Commit
0ce7e6e
·
verified ·
1 Parent(s): 4601fa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -21
app.py CHANGED
@@ -7,6 +7,8 @@ import pandas as pd
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
 
 
10
 
11
  # Dummy CNN Model
12
  class SimpleCNN(nn.Module):
@@ -182,9 +184,36 @@ if uploaded_file is not None:
182
  fig = create_3d_plot(st.session_state.filtered_fft, downsample)
183
  st.plotly_chart(fig, use_container_width=True)
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # CNN Visualization Section
186
- if st.button("Pass to CNN"):
187
- st.session_state.show_cnn = True
 
 
 
188
 
189
  if st.session_state.show_cnn:
190
  st.subheader("CNN Processing Visualization")
@@ -197,25 +226,88 @@ if uploaded_file is not None:
197
  use_column_width=True,
198
  clamp=True)
199
 
200
- # Display activations
201
  st.write("### First Convolution Layer Activations")
202
  activation = activations.detach().numpy()
203
 
204
- # Check the shape of the activation tensor
205
- if len(activation.shape) == 4: # [batch_size, channels, height, width]
206
- for i in range(activation.shape[1]): # Loop through channels
207
- act_img = activation[0, i, :, :] # Select the first batch and current channel
208
- act_img_normalized = (act_img - act_img.min()) / (act_img.max() - act_img.min()) # Normalize
209
-
210
- # Display activation map
211
- st.write(f"#### Activation Channel {i+1}")
212
- st.image(act_img_normalized,
213
- caption=f"Activation Channel {i+1}",
214
- use_column_width=True)
215
-
216
- # Display activation values in a table
217
- st.write("##### Activation Values:")
218
- activation_df = pd.DataFrame(act_img)
219
- st.dataframe(activation_df)
220
- else:
221
- st.error(f"Unexpected activation shape: {activation.shape}. Expected 4 dimensions.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
+ import matplotlib.pyplot as plt
11
+ import plotly.express as px
12
 
13
  # Dummy CNN Model
14
  class SimpleCNN(nn.Module):
 
184
  fig = create_3d_plot(st.session_state.filtered_fft, downsample)
185
  st.plotly_chart(fig, use_container_width=True)
186
 
187
+ # Custom CSS to style the button
188
+ st.markdown("""
189
+ <style>
190
+ .centered-button {
191
+ display: flex;
192
+ justify-content: center;
193
+ align-items: center;
194
+ margin-top: 20px;
195
+ }
196
+ .stButton>button {
197
+ padding: 20px 40px;
198
+ font-size: 20px;
199
+ background-color: #4CAF50;
200
+ color: white;
201
+ border: none;
202
+ border-radius: 10px;
203
+ cursor: pointer;
204
+ }
205
+ .stButton>button:hover {
206
+ background-color: #45a049;
207
+ }
208
+ </style>
209
+ """, unsafe_allow_html=True)
210
+
211
  # CNN Visualization Section
212
+ with st.container():
213
+ st.markdown('<div class="centered-button">', unsafe_allow_html=True)
214
+ if st.button("Pass to CNN"):
215
+ st.session_state.show_cnn = True
216
+ st.markdown('</div>', unsafe_allow_html=True)
217
 
218
  if st.session_state.show_cnn:
219
  st.subheader("CNN Processing Visualization")
 
226
  use_column_width=True,
227
  clamp=True)
228
 
229
+ # Display activations with improved visualization
230
  st.write("### First Convolution Layer Activations")
231
  activation = activations.detach().numpy()
232
 
233
+ if len(activation.shape) == 4:
234
+ # Create a grid of activation maps
235
+ cols = 4 # Number of columns in the grid
236
+ rows = 4 # 16 channels / 4 columns = 4 rows
237
+ fig, axs = plt.subplots(rows, cols, figsize=(20, 20))
238
+
239
+ for i in range(activation.shape[1]):
240
+ act_img = activation[0, i, :, :]
241
+ ax = axs[i//cols, i%cols]
242
+ ax.imshow(act_img, cmap='viridis')
243
+ ax.set_title(f'Channel {i+1}')
244
+ ax.axis('off')
245
+
246
+ st.pyplot(fig)
247
+
248
+ # Display sample activation values
249
+ st.write("### Activation Values Sample")
250
+ sample_activation = activation[0, 0, :10, :10] # First 10x10 values
251
+ st.dataframe(pd.DataFrame(sample_activation))
252
+
253
+ # Additional Steps After Activation Channels
254
+ st.markdown("---")
255
+ st.subheader("Next Processing Steps in CNN")
256
+
257
+ # Step 2: Second Convolution Layer Visualization
258
+ st.write("### Second Convolution Layer Features")
259
+ with torch.no_grad():
260
+ model = SimpleCNN()
261
+ output, activations = model(magnitude_tensor)
262
+ second_conv = model.conv2(activations).detach().numpy()
263
+
264
+ if len(second_conv.shape) == 4:
265
+ cols = 8 # 32 channels / 8 columns = 4 rows
266
+ rows = 4
267
+ fig2, axs2 = plt.subplots(rows, cols, figsize=(20, 10))
268
+
269
+ for i in range(second_conv.shape[1]):
270
+ act_img = second_conv[0, i, :, :]
271
+ ax = axs2[i//cols, i%cols]
272
+ ax.imshow(act_img, cmap='plasma')
273
+ ax.set_title(f'Channel {i+1}')
274
+ ax.axis('off')
275
+
276
+ st.pyplot(fig2)
277
+
278
+ # Step 3: Pooling Layer Visualization
279
+ st.write("### Adaptive Average Pooling Output")
280
+ with torch.no_grad():
281
+ pooled = F.adaptive_avg_pool2d(torch.tensor(second_conv), (8, 8)).numpy()
282
+
283
+ st.write("Pooled Features Shape:", pooled.shape)
284
+
285
+ # Normalize and display pooled features
286
+ pooled_sample = pooled[0, 0]
287
+ pooled_normalized = (pooled_sample - pooled_sample.min()) / (pooled_sample.max() - pooled_sample.min())
288
+ st.image(pooled_normalized,
289
+ caption="Sample Pooled Feature Map",
290
+ use_container_width=True,
291
+ clamp=True)
292
+
293
+ # Step 4: Final Classification
294
+ st.write("### Final Classification Scores")
295
+ with torch.no_grad():
296
+ model = SimpleCNN()
297
+ output, _ = model(magnitude_tensor)
298
+ scores = F.softmax(output, dim=1).numpy()
299
+
300
+ classes = [f"Class {i}" for i in range(10)]
301
+ fig3 = px.bar(x=classes, y=scores[0], title="Classification Probabilities")
302
+ st.plotly_chart(fig3)
303
+
304
+ # Step 5: Full Process Explanation
305
+ st.markdown("""
306
+ #### Processing Pipeline:
307
+ 1. Input Magnitude Spectrum →
308
+ 2. Conv1 Features (16 channels) →
309
+ 3. Conv2 Features (32 channels) →
310
+ 4. Pooled Features →
311
+ 5. Fully Connected Layers →
312
+ 6. Final Classification
313
+ """)