cyberandy committed on
Commit
5ac398b
·
verified ·
1 Parent(s): 96ca74e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -227
app.py CHANGED
@@ -1,8 +1,110 @@
1
  import gradio as gr
2
  import requests
3
  from typing import Dict, Tuple, List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Define custom CSS with Open Sans font and color theme
6
  css = """
7
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
8
 
@@ -10,14 +112,6 @@ body {
10
  font-family: 'Open Sans', sans-serif !important;
11
  }
12
 
13
- .primary-btn {
14
- background-color: #3452db !important;
15
- }
16
-
17
- .primary-btn:hover {
18
- background-color: #2a41af !important;
19
- }
20
-
21
  .feature-card {
22
  border: 1px solid #e0e5ff;
23
  background-color: #ffffff;
@@ -29,25 +123,6 @@ body {
29
  box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
30
  }
31
 
32
- .feature-card.selected {
33
- border: 2px solid #3452db;
34
- background-color: #eef1ff;
35
- }
36
-
37
- .show-more-btn {
38
- color: #3452db;
39
- font-weight: 600;
40
- }
41
-
42
- .show-more-btn:hover {
43
- color: #2a41af;
44
- }
45
-
46
- .token-header {
47
- color: #152156;
48
- font-weight: 700;
49
- }
50
-
51
  .dashboard-container {
52
  border: 1px solid #e0e5ff;
53
  border-radius: 8px;
@@ -55,7 +130,6 @@ body {
55
  }
56
  """
57
 
58
- # Create custom theme
59
  theme = gr.themes.Soft(
60
  primary_hue=gr.themes.colors.Color(
61
  name="blue",
@@ -73,229 +147,96 @@ theme = gr.themes.Soft(
73
  )
74
  )
75
 
76
- def get_features(text: str) -> Dict:
77
- """Get neural features from the API using the exact website parameters."""
78
- url = "https://www.neuronpedia.org/api/search-with-topk"
79
- payload = {
80
- "modelId": "gemma-2-2b",
81
- "text": text,
82
- "layer": "20-gemmascope-res-16k"
83
- }
84
-
85
- try:
86
- response = requests.post(
87
- url,
88
- headers={"Content-Type": "application/json"},
89
- json=payload
90
- )
91
- response.raise_for_status()
92
- return response.json()
93
- except Exception as e:
94
- return None
95
-
96
- def create_feature_html(feature_id: int, activation: float, selected: bool = False) -> str:
97
- """Create HTML for an individual feature card."""
98
- selected_class = "selected" if selected else ""
99
- return f"""
100
- <div class="feature-card {selected_class} p-4 rounded-lg mb-4"
101
- data-feature-id="{feature_id}"
102
- onclick="selectFeature(this, {feature_id}, {activation})">
103
- <div class="flex justify-between items-center">
104
- <div>
105
- <span class="font-semibold">Feature {feature_id}</span>
106
- <span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
107
- </div>
108
- </div>
109
- </div>
110
- """
111
-
112
- def create_token_section(token: str, features: List[Dict], initial_count: int = 3) -> str:
113
- """Create HTML for a token section with its features."""
114
- features_html = "".join([
115
- create_feature_html(f['feature_index'], f['activation_value'])
116
- for f in features[:initial_count]
117
- ])
118
-
119
- show_more = ""
120
- if len(features) > initial_count:
121
- remaining = len(features) - initial_count
122
- hidden_features = "".join([
123
- create_feature_html(f['feature_index'], f['activation_value'])
124
- for f in features[initial_count:]
125
- ])
126
- show_more = f"""
127
- <div class="hidden" id="more-features-{token}">{hidden_features}</div>
128
- <button id="toggle-btn-{token}"
129
- class="show-more-btn text-sm mt-2"
130
- onclick="toggleFeatures('{token}')">
131
- Show {remaining} More Features
132
- </button>
133
- """
134
-
135
- return f"""
136
- <div class="mb-6">
137
- <h2 class="token-header text-xl mb-4">Token: {token}</h2>
138
- <div id="features-{token}">
139
- {features_html}
140
- </div>
141
- {show_more}
142
- </div>
143
- """
144
-
145
- def create_dashboard_html(feature_id: int, activation: float) -> str:
146
- """Create HTML for the feature dashboard."""
147
- return f"""
148
- <div class="dashboard-container p-4">
149
- <h3 class="text-lg font-semibold mb-4 text-gray-900">
150
- Feature {feature_id} Dashboard (Activation: {activation:.2f})
151
- </h3>
152
- <iframe
153
- src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
154
- width="100%"
155
- height="600"
156
- frameborder="0"
157
- class="rounded-lg"
158
- ></iframe>
159
- </div>
160
- """
161
-
162
- def create_interface_html(data: Dict) -> str:
163
- """Create the complete interface HTML with JavaScript functionality."""
164
- js_code = """
165
- <script>
166
- function updateDashboard(featureId, activation) {
167
- const dashboardContainer = document.getElementById('dashboard-container');
168
- dashboardContainer.innerHTML = `
169
- <div class="dashboard-container p-4">
170
- <h3 class="text-lg font-semibold mb-4 text-gray-900">
171
- Feature ${featureId} Dashboard (Activation: ${activation.toFixed(2)})
172
- </h3>
173
- <iframe
174
- src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/${featureId}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
175
- width="100%"
176
- height="600"
177
- frameborder="0"
178
- class="rounded-lg"
179
- ></iframe>
180
- </div>
181
- `;
182
- }
183
-
184
- function selectFeature(element, featureId, activation) {
185
- // Update selected state visually
186
- document.querySelectorAll('.feature-card').forEach(card => {
187
- card.classList.remove('selected');
188
- });
189
- element.classList.add('selected');
190
-
191
- // Update dashboard
192
- updateDashboard(featureId, activation);
193
- }
194
 
195
- function toggleFeatures(token) {
196
- const moreFeatures = document.getElementById(`more-features-${token}`);
197
- const featuresContainer = document.getElementById(`features-${token}`);
198
- const toggleButton = document.getElementById(`toggle-btn-${token}`);
199
-
200
- if (moreFeatures.classList.contains('hidden')) {
201
- // Show additional features
202
- moreFeatures.classList.remove('hidden');
203
- const additionalFeatures = moreFeatures.innerHTML;
204
- featuresContainer.insertAdjacentHTML('beforeend', additionalFeatures);
205
- toggleButton.textContent = 'Show Less';
206
- } else {
207
- // Hide additional features
208
- const allFeatures = featuresContainer.querySelectorAll('.feature-card');
209
- Array.from(allFeatures).slice(3).forEach(card => card.remove());
210
- moreFeatures.classList.add('hidden');
211
- toggleButton.textContent = `Show ${moreFeatures.children.length} More Features`;
212
- }
213
  }
214
- </script>
215
- """
 
 
216
 
217
- tokens_html = ""
218
- dashboard_html = ""
219
- first_feature = None
 
 
 
 
 
220
 
221
- for result in data['results']:
222
- if result['token'] == '<bos>':
223
- continue
224
-
225
- tokens_html += create_token_section(result['token'], result['top_features'])
226
 
227
- if not first_feature and result['top_features']:
228
- first_feature = result['top_features'][0]
229
- dashboard_html = create_dashboard_html(
230
- first_feature['feature_index'],
231
- first_feature['activation_value']
232
- )
 
 
233
 
234
- return f"""
235
- <div class="p-6">
236
- {js_code}
237
- <div class="grid grid-cols-1 lg:grid-cols-2 gap-8">
238
- <div class="space-y-6">
239
- {tokens_html}
240
- </div>
241
- <div class="lg:sticky lg:top-6">
242
- <div id="dashboard-container">
243
- {dashboard_html}
244
- </div>
245
- </div>
246
- </div>
247
- </div>
248
- """
249
-
250
- def analyze_features(text: str) -> Tuple[str, str, str]:
251
- data = get_features(text)
252
- if not data:
253
- return "Error analyzing text", "", ""
254
 
255
- interface_html = create_interface_html(data)
256
- return interface_html, "", ""
257
 
258
  def create_interface():
 
 
259
  with gr.Blocks(theme=theme, css=css) as interface:
260
- gr.Markdown(
261
- "# Brand Feature Analyzer",
262
- elem_classes="text-2xl font-bold text-gray-900 mb-2"
263
- )
264
- gr.Markdown(
265
- "*Analyze your brand using Gemma's interpretable neural features*",
266
- elem_classes="text-gray-600 mb-6"
267
- )
268
 
269
  with gr.Row():
270
- with gr.Column():
271
  input_text = gr.Textbox(
272
  lines=5,
273
  placeholder="Enter text to analyze...",
274
- label="Input Text",
275
- elem_classes="mb-4"
276
- )
277
- analyze_btn = gr.Button(
278
- "Analyze Features",
279
- variant="primary",
280
- elem_classes="primary-btn"
281
  )
282
- # Examples without elem_classes
283
  gr.Examples(
284
  examples=["WordLift", "Think Different", "Just Do It"],
285
  inputs=input_text
286
  )
287
 
288
- with gr.Column():
289
- output_html = gr.HTML()
290
- feature_label = gr.Text(show_label=False, visible=False)
291
- dashboard = gr.HTML(visible=False)
292
 
 
293
  analyze_btn.click(
294
  fn=analyze_features,
295
- inputs=input_text,
296
- outputs=[output_html, feature_label, dashboard]
297
  )
298
-
299
  return interface
300
 
301
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import requests
3
  from typing import Dict, Tuple, List
4
+ import json
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
@dataclass
class Feature:
    """One feature activation reported by the API for a single token."""

    # Taken from the API's 'feature_index'; also interpolated into the
    # Neuronpedia dashboard URL.
    feature_id: int
    # Taken from the API's 'activation_value'; rendered with two decimals.
    activation: float
    # Token text this feature was reported for.
    token: str
    # Index of this feature within the token's 'top_features' list
    # (not the token's position in the input text).
    position: int
14
+
15
class FeatureState:
    """Container for the analyzer's state.

    Holds the features grouped by token, the set of tokens whose feature
    list is expanded, and the currently selected feature (or None).

    NOTE(review): nothing else in the visible code instantiates this class;
    the handlers pass the same three keys around in a plain dict instead.
    """

    def __init__(self):
        # Start empty: no tokens analysed, nothing expanded, nothing selected.
        self.features_by_token = dict()
        self.expanded_tokens = set()
        self.selected_feature = None
20
+
21
def get_features(text: str, timeout: float = 30.0) -> Optional[Dict]:
    """Fetch top-k feature activations for *text* from the Neuronpedia API.

    Posts *text* to the ``search-with-topk`` endpoint for model
    ``gemma-2-2b`` / layer ``20-gemmascope-res-16k``.

    Args:
        text: Text to analyse.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled
            connection, which would hang the UI callback.

    Returns:
        The decoded JSON response, or ``None`` when the request fails
        (network error, timeout, non-2xx status) or the body is not valid
        JSON.
    """
    import logging  # function-scope: the module has no file-level logging import

    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }

    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decoding failures on older requests versions,
        # where the decode error is not a RequestException subclass.
        # Callers treat None as "analysis failed"; log the cause so the
        # failure is not completely silent.
        logging.getLogger(__name__).warning("Neuronpedia request failed: %s", exc)
        return None
40
+
41
def format_feature_list(features: List["Feature"], token: str, expanded: bool = False,
                        initial_count: int = 3) -> str:
    """Render a token's features as a sequence of HTML feature cards.

    Args:
        features: Features to render, in display order. Each item must
            expose ``feature_id`` and ``activation``.
        token: Token the features belong to. Accepted for interface
            compatibility; it is not used in the generated markup.
        expanded: When true, render every feature; otherwise render only the
            first ``initial_count`` and append a "N more features available"
            hint if any were left out.
        initial_count: Number of cards shown in the collapsed view
            (negative values are treated as 0).

    Returns:
        The concatenated HTML, or an empty string when ``features`` is empty.
    """
    shown = features if expanded else features[:max(initial_count, 0)]

    # Collect fragments and join once instead of repeated string +=.
    parts = []
    for feature in shown:
        parts.append(f"""
        <div class="feature-card p-4 rounded-lg mb-4 cursor-pointer hover:border-blue-500"
             data-feature-id="{feature.feature_id}">
            <div class="flex justify-between items-center">
                <div>
                    <span class="font-semibold">Feature {feature.feature_id}</span>
                    <span class="ml-2 text-gray-600">(Activation: {feature.activation:.2f})</span>
                </div>
            </div>
        </div>
        """)

    # Only the collapsed view can hide features; in expanded view this is 0.
    remaining = len(features) - len(shown)
    if remaining > 0:
        parts.append(f"""
        <div class="text-center">
            <span class="text-blue-500 text-sm">{remaining} more features available</span>
        </div>
        """)

    return "".join(parts)
68
+
69
def format_dashboard(feature: "Feature") -> str:
    """Build the dashboard panel HTML for *feature*.

    The panel has a heading with the feature id and its activation (two
    decimals) and an iframe embedding the Neuronpedia page for that feature
    under gemma-2-2b / 20-gemmascope-res-16k.

    Returns an empty string when *feature* is falsy (e.g. None).
    """
    if feature:
        fid = feature.feature_id
        act = feature.activation
        return f"""
        <div class="dashboard-container p-4">
            <h3 class="text-lg font-semibold mb-4 text-gray-900">
                Feature {fid} Dashboard (Activation: {act:.2f})
            </h3>
            <iframe
                src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{fid}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
                width="100%"
                height="600"
                frameborder="0"
                class="rounded-lg"
            ></iframe>
        </div>
        """

    # Nothing selected: render no dashboard.
    return ""
88
+
89
def process_features(data: Dict) -> Dict[str, List["Feature"]]:
    """Convert the API response into ``{token: [Feature, ...]}``.

    Iterates ``data['results']`` (missing key means no results), skips the
    ``'<bos>'`` token, and wraps each entry of a result's ``'top_features'``
    in a ``Feature`` whose ``position`` is its index in that list.

    NOTE(review): the mapping is keyed by token text, so when the same token
    occurs more than once in the results the later occurrence replaces the
    earlier one.
    """
    grouped = {}
    for entry in data.get('results', []):
        token = entry['token']
        if token == '<bos>':
            continue

        grouped[token] = [
            Feature(
                feature_id=raw['feature_index'],
                activation=raw['activation_value'],
                token=token,
                position=rank,
            )
            for rank, raw in enumerate(entry.get('top_features', []))
        ]
    return grouped
107
 
 
108
  css = """
109
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
110
 
 
112
  font-family: 'Open Sans', sans-serif !important;
113
  }
114
 
 
 
 
 
 
 
 
 
115
  .feature-card {
116
  border: 1px solid #e0e5ff;
117
  background-color: #ffffff;
 
123
  box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
124
  }
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  .dashboard-container {
127
  border: 1px solid #e0e5ff;
128
  border-radius: 8px;
 
130
  }
131
  """
132
 
 
133
  theme = gr.themes.Soft(
134
  primary_hue=gr.themes.colors.Color(
135
  name="blue",
 
147
  )
148
  )
149
 
150
def analyze_features(text: str, state: Optional[Dict] = None) -> Tuple[str, Optional[Dict]]:
    """Analyse *text* (or re-render an existing *state*) and return HTML plus state.

    Two modes:

    * ``text`` is non-empty: query the API via ``get_features``, group the
      result with ``process_features`` and build a fresh state. Any incoming
      ``state`` is discarded so that features, expansion flags and the
      selection never refer to a previous input.
    * ``text`` is empty/None and ``state`` already holds features: re-render
      from that state without calling the API. This is the path used by
      ``toggle_expansion`` and ``select_feature``.

    Args:
        text: Text to analyse, or a falsy value to re-render from ``state``.
        state: Dict with keys ``'features_by_token'``, ``'expanded_tokens'``
            and ``'selected_feature'``; may be None or empty.

    Returns:
        ``(html, state)``. ``("", None)`` when there is neither text nor
        features to render; ``("Error analyzing text", None)`` when the API
        call fails.
    """
    import html  # function-scope: the module does not import html at file level

    if text:
        data = get_features(text)
        if not data:
            return "Error analyzing text", None

        features_by_token = process_features(data)
        state = {
            'features_by_token': features_by_token,
            'expanded_tokens': set(),
            # Default selection: first feature of the first token that has
            # any. Falls back to None instead of raising when nothing is left
            # after filtering (e.g. the response held only '<bos>').
            'selected_feature': next(
                (features[0] for features in features_by_token.values() if features),
                None,
            ),
        }
    elif state and state.get('features_by_token'):
        features_by_token = state['features_by_token']
    else:
        return "", None

    expanded_tokens = state.get('expanded_tokens') or set()

    sections = []
    for token, features in features_by_token.items():
        # Token text comes from the model's vocabulary and may contain
        # characters such as '<' or '&'; escape it before embedding in HTML.
        header = f"<h2 class='text-xl font-bold mb-4'>Token: {html.escape(token)}</h2>"
        cards = format_feature_list(features, token, token in expanded_tokens)
        sections.append(f"<div class='mb-6'>{header}{cards}</div>")

    # Add dashboard if a feature is selected
    selected = state.get('selected_feature')
    if selected:
        sections.append(format_dashboard(selected))

    return "\n".join(sections), state
188
+
189
def toggle_expansion(token: str, state: Dict) -> Tuple[str, Dict]:
    """Flip the expanded/collapsed flag for *token* and re-render.

    Args:
        token: Token whose feature list should be expanded or collapsed.
        state: State dict; its ``'expanded_tokens'`` set is updated in place.
            A None state or one without that key (such as the empty dict the
            interface seeds its state with) is tolerated and gets an empty
            set first.

    Returns:
        Whatever ``analyze_features(None, state)`` returns for the updated
        state.
    """
    if state is None:
        state = {}

    expanded = state.setdefault('expanded_tokens', set())
    if token in expanded:
        expanded.remove(token)
    else:
        expanded.add(token)

    return analyze_features(None, state)
198
+
199
def select_feature(feature_id: int, state: Dict) -> Tuple[str, Dict]:
    """Mark the feature with *feature_id* as selected and re-render.

    Tokens are searched in the order they are stored in
    ``state['features_by_token']`` and the first feature whose
    ``feature_id`` matches wins. If nothing matches, or the state holds no
    features, the current selection is left unchanged.

    Args:
        feature_id: Id of the feature to select.
        state: State dict; ``'selected_feature'`` is updated in place. May be
            None or lack ``'features_by_token'``.

    Returns:
        Whatever ``analyze_features(None, state)`` returns for the updated
        state.
    """
    if state is None:
        state = {}

    candidates = (
        feature
        for features in state.get('features_by_token', {}).values()
        for feature in features
    )
    # next() stops at the first hit across *all* tokens; a bare `break` in a
    # nested loop would only leave the inner loop and let a later token
    # override the selection.
    match = next((f for f in candidates if f.feature_id == feature_id), None)
    if match is not None:
        state['selected_feature'] = match

    return analyze_features(None, state)
209
 
210
  def create_interface():
211
+ state = gr.State({})
212
+
213
  with gr.Blocks(theme=theme, css=css) as interface:
214
+ gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
215
+ gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
 
 
 
 
 
 
216
 
217
  with gr.Row():
218
+ with gr.Column(scale=1):
219
  input_text = gr.Textbox(
220
  lines=5,
221
  placeholder="Enter text to analyze...",
222
+ label="Input Text"
 
 
 
 
 
 
223
  )
224
+ analyze_btn = gr.Button("Analyze Features", variant="primary")
225
  gr.Examples(
226
  examples=["WordLift", "Think Different", "Just Do It"],
227
  inputs=input_text
228
  )
229
 
230
+ with gr.Column(scale=2):
231
+ output = gr.HTML()
 
 
232
 
233
+ # Event handlers
234
  analyze_btn.click(
235
  fn=analyze_features,
236
+ inputs=[input_text, state],
237
+ outputs=[output, state]
238
  )
239
+
240
  return interface
241
 
242
  if __name__ == "__main__":