cyberandy committed on
Commit
321a1b2
·
verified ·
1 Parent(s): c3f5f94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -89
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import gradio as gr
2
  import requests
3
- from typing import Dict, Tuple, List
 
4
 
5
  def get_features(text: str) -> Dict:
6
- """Get neural features from the API using the exact website parameters."""
7
  url = "https://www.neuronpedia.org/api/search-with-topk"
8
  payload = {
9
  "modelId": "gemma-2-2b",
@@ -22,89 +23,127 @@ def get_features(text: str) -> Dict:
22
  except Exception as e:
23
  return None
24
 
25
def format_feature_list(token: str, features: List[Dict], show_all: bool = False) -> str:
    """Render one token's features as HTML cards.

    Args:
        token: Token the features belong to. It is not used in the markup;
            the parameter is kept so existing callers keep working.
        features: Feature records, each providing ``feature_index`` and
            ``activation_value``.
        show_all: When true every feature is rendered; otherwise only the
            first three are, followed by a notice counting the rest.

    Returns:
        The concatenated HTML fragments, or an empty string when
        ``features`` is empty.
    """
    preview_limit = 3
    visible = features if show_all else features[:preview_limit]

    # Collect fragments and join once instead of growing a string with +=.
    parts = []

    for position, feature in enumerate(visible):
        feature_id = feature['feature_index']
        activation = feature['activation_value']

        parts.append(f"""
        <div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500">
            <div class="flex justify-between items-center">
                <div>
                    <span class="font-semibold">Feature {feature_id}</span>
                    <span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
                </div>
                <a href="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}"
                   target="_blank"
                   class="text-blue-600 hover:text-blue-800">
                    View on Neuronpedia →
                </a>
            </div>
        </div>
        """)

        # Only the first feature gets an embedded dashboard.
        if position == 0:
            parts.append(f"""
            <div class="dashboard-container mb-6 p-4">
                <h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
                <iframe
                    src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
                    width="100%"
                    height="600"
                    frameborder="0"
                    class="rounded-lg"
                ></iframe>
            </div>
            """)

    remaining = len(features) - preview_limit
    if not show_all and remaining > 0:
        parts.append(f"""
        <div class="text-sm text-gray-600 mb-4">
            {remaining} more features available.
            <a href="https://www.neuronpedia.org/gemma-2-2b" target="_blank" class="text-blue-600 hover:text-blue-800">
                View all on Neuronpedia →
            </a>
        </div>
        """)

    return "".join(parts)
78
-
79
def analyze_features(text: str) -> str:
    """Fetch features for ``text`` and render them as one HTML document.

    Args:
        text: Text to analyze.

    Returns:
        An empty string when ``text`` is empty, the literal
        ``"Error analyzing text"`` when ``get_features`` returns nothing,
        otherwise the HTML for every token except ``<bos>``.
    """
    from html import escape

    if not text:
        return ""

    data = get_features(text)
    if not data:
        return "Error analyzing text"

    output = ['<div class="p-6">']

    for result in data['results']:
        token = result['token']
        if token == '<bos>':
            continue

        features = result['top_features']

        # Tokens come from user text and may contain markup characters, so
        # escape them before placing them in the heading.
        output.append(f"""
        <div class="mb-8">
            <h2 class="text-xl font-bold mb-4">Token: {escape(str(token))}</h2>
            {format_feature_list(token, features)}
        </div>
        """)

    output.append('</div>')
    return "\n".join(output)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  css = """
109
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
110
 
@@ -127,30 +166,26 @@ body {
127
  border-radius: 8px;
128
  background-color: #ffffff;
129
  }
130
-
131
- .hljs {
132
- background: #f5f7ff !important;
133
- }
134
  """
135
 
136
  theme = gr.themes.Soft(
137
  primary_hue=gr.themes.colors.Color(
138
  name="blue",
139
- c50="#eef1ff",
140
- c100="#e0e5ff",
141
- c200="#c3cbff",
142
- c300="#a5b2ff",
143
- c400="#8798ff",
144
- c500="#6a7eff",
145
- c600="#3452db",
146
- c700="#2a41af",
147
- c800="#1f3183",
148
- c900="#152156",
149
- c950="#0a102b",
150
  )
151
  )
152
 
153
  def create_interface():
 
 
 
 
 
 
 
154
  with gr.Blocks(theme=theme, css=css) as interface:
155
  gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
156
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
@@ -169,13 +204,22 @@ def create_interface():
169
  )
170
 
171
  with gr.Column(scale=2):
172
- output = gr.HTML()
 
173
 
 
174
  analyze_btn.click(
175
  fn=analyze_features,
176
- inputs=input_text,
177
- outputs=output
178
  )
 
 
 
 
 
 
 
179
 
180
  return interface
181
 
 
1
  import gradio as gr
2
  import requests
3
+ from typing import Dict, List, Tuple
4
+ import json
5
 
6
  def get_features(text: str) -> Dict:
7
+ """Get neural features from the API."""
8
  url = "https://www.neuronpedia.org/api/search-with-topk"
9
  payload = {
10
  "modelId": "gemma-2-2b",
 
23
  except Exception as e:
24
  return None
25
 
26
def format_features(features_data: Dict, expanded_tokens: List[str], selected_feature: Dict) -> str:
    """Render every token's features as HTML, honouring expansion state.

    Args:
        features_data: API response; must contain a ``results`` list whose
            items provide ``token`` and ``top_features``.
        expanded_tokens: Tokens whose full feature list is shown. Tokens not
            listed show at most three features. ``None`` is treated as empty.
        selected_feature: Currently selected feature (read via its
            ``feature_id`` key), or ``None``. Matching cards get a highlight
            class.

    Returns:
        The HTML document, or an empty string when ``features_data`` is
        empty or has no ``results`` key.
    """
    from html import escape
    import json

    if not features_data or 'results' not in features_data:
        return ""

    preview_limit = 3
    expanded = expanded_tokens or []
    output = ['<div class="p-6">']

    for result in features_data['results']:
        token = result['token']
        if token == '<bos>':
            continue

        features = result['top_features']
        is_expanded = token in expanded
        visible = features if is_expanded else features[:preview_limit]

        # Tokens originate from user text: escape them for the heading, and
        # JSON-encode then attribute-escape them for the inline handler so
        # quotes or angle brackets cannot break the markup or the script.
        token_text = escape(str(token))
        token_js = escape(json.dumps(str(token)), quote=True)

        output.append(
            f'<div class="mb-8"><h2 class="text-xl font-bold mb-4">Token: {token_text}</h2>'
        )

        for feature in visible:
            feature_id = feature['feature_index']
            activation = feature['activation_value']
            is_selected = bool(selected_feature) and selected_feature.get('feature_id') == feature_id
            selected_class = "border-blue-500 border-2" if is_selected else ""

            output.append(f"""
            <div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500 {selected_class}">
                <div class="flex justify-between items-center">
                    <div>
                        <span class="font-semibold">Feature {feature_id}</span>
                        <span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
                    </div>
                </div>
            </div>
            """)

        # Offer a toggle only when some features are hidden by default.
        if len(features) > preview_limit:
            action = "less" if is_expanded else f"{len(features) - preview_limit} more"
            # NOTE(review): `gradio(...)` must be defined by the page's own
            # JavaScript; nothing in this function provides it.
            output.append(f"""
            <div class="text-center mb-4">
                <button class="text-blue-600 hover:text-blue-800 text-sm"
                        onclick="gradio('toggle_expansion', {token_js})">
                    Show {action} features
                </button>
            </div>
            """)

        output.append('</div>')

    output.append('</div>')
    return "\n".join(output)
81
 
82
def format_dashboard(feature: Dict) -> str:
    """Render the embedded Neuronpedia dashboard for one feature.

    Args:
        feature: Mapping with ``feature_id`` and a numeric ``activation``,
            or a falsy value when nothing is selected.

    Returns:
        HTML containing a heading and an iframe for the feature, or an
        empty string when ``feature`` is falsy.

    Raises:
        KeyError: If a non-empty ``feature`` lacks ``feature_id`` or
            ``activation``.
    """
    from html import escape
    from urllib.parse import quote

    if not feature:
        return ""

    feature_id = str(feature['feature_id'])
    activation = feature['activation']

    # The id comes from an external API response: escape it for the heading
    # and percent-encode it for the URL path so it cannot inject markup.
    id_text = escape(feature_id)
    id_path = quote(feature_id, safe="")

    return f"""
    <div class="dashboard-container p-4">
        <h3 class="text-lg font-semibold mb-4">
            Feature {id_text} Dashboard (Activation: {activation:.2f})
        </h3>
        <iframe
            src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{id_path}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
            width="100%"
            height="600"
            frameborder="0"
            class="rounded-lg"
        ></iframe>
    </div>
    """
104
+
105
def analyze_features(text: str, state: Dict) -> Tuple[str, str, Dict]:
    """Analyze ``text`` and refresh the feature list, dashboard and state.

    Args:
        text: Text to analyze.
        state: Session state dict using the keys ``features_data``,
            ``expanded_tokens`` and ``selected_feature``. It is updated in
            place.

    Returns:
        ``(features_html, dashboard_html, state)``. Empty ``text`` yields two
        empty strings. A missing or unusable API response yields
        ``"Error analyzing text"`` as the first element. In both cases the
        state is returned untouched.
    """
    if not text:
        return "", "", state

    features_data = get_features(text)
    # A payload without 'results' (e.g. an error body) is as unusable as no
    # response at all; format_features applies the same guard.
    if not features_data or 'results' not in features_data:
        return "Error analyzing text", "", state

    if state is None:
        state = {}

    state['features_data'] = features_data
    if not state.get('expanded_tokens'):
        state['expanded_tokens'] = []

    # Default the selection to the first feature of the first real token.
    # NOTE(review): an existing selection is kept even when it came from a
    # previous text — confirm whether it should be reset on each analysis.
    if not state.get('selected_feature'):
        for result in features_data['results']:
            if result['token'] != '<bos>' and result['top_features']:
                first_feature = result['top_features'][0]
                state['selected_feature'] = {
                    'feature_id': first_feature['feature_index'],
                    'activation': first_feature['activation_value'],
                }
                break

    # .get: the key may still be unset when no token carried any feature.
    selected = state.get('selected_feature')
    features_html = format_features(features_data, state['expanded_tokens'], selected)
    dashboard_html = format_dashboard(selected)

    return features_html, dashboard_html, state
134
+
135
def toggle_expansion(token: str, state: Dict) -> Tuple[str, str, Dict]:
    """Flip whether ``token`` is expanded and re-render both panels.

    Args:
        token: Token whose expansion state is toggled.
        state: Session state dict; its ``expanded_tokens`` list is modified
            in place, and ``features_data`` and ``selected_feature`` are read.

    Returns:
        ``(features_html, dashboard_html, state)`` rendered from the updated
        state.
    """
    expanded = state['expanded_tokens']
    if token not in expanded:
        expanded.append(token)
    else:
        expanded.remove(token)

    return (
        format_features(state['features_data'], state['expanded_tokens'], state['selected_feature']),
        format_dashboard(state['selected_feature']),
        state,
    )
146
+
147
  css = """
148
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
149
 
 
166
  border-radius: 8px;
167
  background-color: #ffffff;
168
  }
 
 
 
 
169
  """
170
 
171
  theme = gr.themes.Soft(
172
  primary_hue=gr.themes.colors.Color(
173
  name="blue",
174
+ c50="#eef1ff", c100="#e0e5ff", c200="#c3cbff",
175
+ c300="#a5b2ff", c400="#8798ff", c500="#6a7eff",
176
+ c600="#3452db", c700="#2a41af", c800="#1f3183",
177
+ c900="#152156", c950="#0a102b",
 
 
 
 
 
 
 
178
  )
179
  )
180
 
181
  def create_interface():
182
+ # Initialize state
183
+ state = gr.State({
184
+ 'features_data': None,
185
+ 'expanded_tokens': [],
186
+ 'selected_feature': None
187
+ })
188
+
189
  with gr.Blocks(theme=theme, css=css) as interface:
190
  gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
191
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
 
204
  )
205
 
206
  with gr.Column(scale=2):
207
+ features_html = gr.HTML()
208
+ dashboard_html = gr.HTML()
209
 
210
+ # Event handlers
211
  analyze_btn.click(
212
  fn=analyze_features,
213
+ inputs=[input_text, state],
214
+ outputs=[features_html, dashboard_html, state]
215
  )
216
+
217
+ # Custom JavaScript function for token expansion
218
+ interface.load(None, None, None, _js="""
219
+ function toggle_expansion(token) {
220
+ // Function will be called from HTML onclick
221
+ }
222
+ """)
223
 
224
  return interface
225