Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
-
from typing import Dict,
|
|
|
4 |
|
5 |
def get_features(text: str) -> Dict:
|
6 |
-
"""Get neural features from the API
|
7 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
8 |
payload = {
|
9 |
"modelId": "gemma-2-2b",
|
@@ -22,89 +23,127 @@ def get_features(text: str) -> Dict:
|
|
22 |
except Exception as e:
|
23 |
return None
|
24 |
|
25 |
-
def
|
26 |
-
"""Format features as HTML
|
27 |
-
|
28 |
-
features_html = ""
|
29 |
-
|
30 |
-
for idx in range(feature_count):
|
31 |
-
feature = features[idx]
|
32 |
-
feature_id = feature['feature_index']
|
33 |
-
activation = feature['activation_value']
|
34 |
-
|
35 |
-
features_html += f"""
|
36 |
-
<div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500">
|
37 |
-
<div class="flex justify-between items-center">
|
38 |
-
<div>
|
39 |
-
<span class="font-semibold">Feature {feature_id}</span>
|
40 |
-
<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
|
41 |
-
</div>
|
42 |
-
<a href="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}"
|
43 |
-
target="_blank"
|
44 |
-
class="text-blue-600 hover:text-blue-800">
|
45 |
-
View on Neuronpedia →
|
46 |
-
</a>
|
47 |
-
</div>
|
48 |
-
</div>
|
49 |
-
"""
|
50 |
-
|
51 |
-
# Add dashboard for first feature only
|
52 |
-
if idx == 0:
|
53 |
-
features_html += f"""
|
54 |
-
<div class="dashboard-container mb-6 p-4">
|
55 |
-
<h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
|
56 |
-
<iframe
|
57 |
-
src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
|
58 |
-
width="100%"
|
59 |
-
height="600"
|
60 |
-
frameborder="0"
|
61 |
-
class="rounded-lg"
|
62 |
-
></iframe>
|
63 |
-
</div>
|
64 |
-
"""
|
65 |
-
|
66 |
-
remaining = len(features) - 3
|
67 |
-
if not show_all and remaining > 0:
|
68 |
-
features_html += f"""
|
69 |
-
<div class="text-sm text-gray-600 mb-4">
|
70 |
-
{remaining} more features available.
|
71 |
-
<a href="https://www.neuronpedia.org/gemma-2-2b" target="_blank" class="text-blue-600 hover:text-blue-800">
|
72 |
-
View all on Neuronpedia →
|
73 |
-
</a>
|
74 |
-
</div>
|
75 |
-
"""
|
76 |
-
|
77 |
-
return features_html
|
78 |
-
|
79 |
-
def analyze_features(text: str) -> str:
|
80 |
-
"""Main analysis function that processes text and returns formatted output."""
|
81 |
-
if not text:
|
82 |
return ""
|
83 |
-
|
84 |
-
data = get_features(text)
|
85 |
-
if not data:
|
86 |
-
return "Error analyzing text"
|
87 |
|
88 |
output = ['<div class="p-6">']
|
89 |
|
90 |
# Process each token's features
|
91 |
-
for result in
|
92 |
if result['token'] == '<bos>':
|
93 |
continue
|
94 |
-
|
95 |
token = result['token']
|
96 |
features = result['top_features']
|
|
|
|
|
97 |
|
98 |
-
output.append(f"""
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
output.append('</div>')
|
106 |
return "\n".join(output)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
css = """
|
109 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
110 |
|
@@ -127,30 +166,26 @@ body {
|
|
127 |
border-radius: 8px;
|
128 |
background-color: #ffffff;
|
129 |
}
|
130 |
-
|
131 |
-
.hljs {
|
132 |
-
background: #f5f7ff !important;
|
133 |
-
}
|
134 |
"""
|
135 |
|
136 |
theme = gr.themes.Soft(
|
137 |
primary_hue=gr.themes.colors.Color(
|
138 |
name="blue",
|
139 |
-
c50="#eef1ff",
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
c400="#8798ff",
|
144 |
-
c500="#6a7eff",
|
145 |
-
c600="#3452db",
|
146 |
-
c700="#2a41af",
|
147 |
-
c800="#1f3183",
|
148 |
-
c900="#152156",
|
149 |
-
c950="#0a102b",
|
150 |
)
|
151 |
)
|
152 |
|
153 |
def create_interface():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
with gr.Blocks(theme=theme, css=css) as interface:
|
155 |
gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
|
156 |
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
@@ -169,13 +204,22 @@ def create_interface():
|
|
169 |
)
|
170 |
|
171 |
with gr.Column(scale=2):
|
172 |
-
|
|
|
173 |
|
|
|
174 |
analyze_btn.click(
|
175 |
fn=analyze_features,
|
176 |
-
inputs=input_text,
|
177 |
-
outputs=
|
178 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
return interface
|
181 |
|
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
+
from typing import Dict, List, Tuple
|
4 |
+
import json
|
5 |
|
6 |
def get_features(text: str) -> Dict:
|
7 |
+
"""Get neural features from the API."""
|
8 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
9 |
payload = {
|
10 |
"modelId": "gemma-2-2b",
|
|
|
23 |
except Exception as e:
|
24 |
return None
|
25 |
|
26 |
+
def format_features(features_data: Dict, expanded_tokens: List[str], selected_feature: Dict) -> str:
|
27 |
+
"""Format features as HTML with expanded state."""
|
28 |
+
if not features_data or 'results' not in features_data:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
return ""
|
|
|
|
|
|
|
|
|
30 |
|
31 |
output = ['<div class="p-6">']
|
32 |
|
33 |
# Process each token's features
|
34 |
+
for result in features_data['results']:
|
35 |
if result['token'] == '<bos>':
|
36 |
continue
|
37 |
+
|
38 |
token = result['token']
|
39 |
features = result['top_features']
|
40 |
+
is_expanded = token in expanded_tokens
|
41 |
+
feature_count = len(features) if is_expanded else min(3, len(features))
|
42 |
|
43 |
+
output.append(f'<div class="mb-8"><h2 class="text-xl font-bold mb-4">Token: {token}</h2>')
|
44 |
+
|
45 |
+
# Display features
|
46 |
+
for idx in range(feature_count):
|
47 |
+
feature = features[idx]
|
48 |
+
feature_id = feature['feature_index']
|
49 |
+
activation = feature['activation_value']
|
50 |
+
is_selected = selected_feature and selected_feature.get('feature_id') == feature_id
|
51 |
+
|
52 |
+
selected_class = "border-blue-500 border-2" if is_selected else ""
|
53 |
+
|
54 |
+
output.append(f"""
|
55 |
+
<div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500 {selected_class}">
|
56 |
+
<div class="flex justify-between items-center">
|
57 |
+
<div>
|
58 |
+
<span class="font-semibold">Feature {feature_id}</span>
|
59 |
+
<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
|
60 |
+
</div>
|
61 |
+
</div>
|
62 |
+
</div>
|
63 |
+
""")
|
64 |
+
|
65 |
+
# Show more/less button if needed
|
66 |
+
if len(features) > 3:
|
67 |
+
action = "less" if is_expanded else f"{len(features) - 3} more"
|
68 |
+
output.append(f"""
|
69 |
+
<div class="text-center mb-4">
|
70 |
+
<button class="text-blue-600 hover:text-blue-800 text-sm"
|
71 |
+
onclick="gradio('toggle_expansion', '{token}')">
|
72 |
+
Show {action} features
|
73 |
+
</button>
|
74 |
+
</div>
|
75 |
+
""")
|
76 |
+
|
77 |
+
output.append('</div>')
|
78 |
|
79 |
output.append('</div>')
|
80 |
return "\n".join(output)
|
81 |
|
82 |
+
def format_dashboard(feature: Dict) -> str:
|
83 |
+
"""Format the feature dashboard."""
|
84 |
+
if not feature:
|
85 |
+
return ""
|
86 |
+
|
87 |
+
feature_id = feature['feature_id']
|
88 |
+
activation = feature['activation']
|
89 |
+
|
90 |
+
return f"""
|
91 |
+
<div class="dashboard-container p-4">
|
92 |
+
<h3 class="text-lg font-semibold mb-4">
|
93 |
+
Feature {feature_id} Dashboard (Activation: {activation:.2f})
|
94 |
+
</h3>
|
95 |
+
<iframe
|
96 |
+
src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
|
97 |
+
width="100%"
|
98 |
+
height="600"
|
99 |
+
frameborder="0"
|
100 |
+
class="rounded-lg"
|
101 |
+
></iframe>
|
102 |
+
</div>
|
103 |
+
"""
|
104 |
+
|
105 |
+
def analyze_features(text: str, state: Dict) -> Tuple[str, str, Dict]:
|
106 |
+
"""Process text and update state."""
|
107 |
+
if not text:
|
108 |
+
return "", "", state
|
109 |
+
|
110 |
+
features_data = get_features(text)
|
111 |
+
if not features_data:
|
112 |
+
return "Error analyzing text", "", state
|
113 |
+
|
114 |
+
# Update state
|
115 |
+
state['features_data'] = features_data
|
116 |
+
if not state.get('expanded_tokens'):
|
117 |
+
state['expanded_tokens'] = []
|
118 |
+
|
119 |
+
# Select first feature by default if none selected
|
120 |
+
if not state.get('selected_feature'):
|
121 |
+
for result in features_data['results']:
|
122 |
+
if result['token'] != '<bos>' and result['top_features']:
|
123 |
+
first_feature = result['top_features'][0]
|
124 |
+
state['selected_feature'] = {
|
125 |
+
'feature_id': first_feature['feature_index'],
|
126 |
+
'activation': first_feature['activation_value']
|
127 |
+
}
|
128 |
+
break
|
129 |
+
|
130 |
+
features_html = format_features(features_data, state['expanded_tokens'], state['selected_feature'])
|
131 |
+
dashboard_html = format_dashboard(state['selected_feature'])
|
132 |
+
|
133 |
+
return features_html, dashboard_html, state
|
134 |
+
|
135 |
+
def toggle_expansion(token: str, state: Dict) -> Tuple[str, str, Dict]:
|
136 |
+
"""Toggle expansion state for a token."""
|
137 |
+
if token in state['expanded_tokens']:
|
138 |
+
state['expanded_tokens'].remove(token)
|
139 |
+
else:
|
140 |
+
state['expanded_tokens'].append(token)
|
141 |
+
|
142 |
+
features_html = format_features(state['features_data'], state['expanded_tokens'], state['selected_feature'])
|
143 |
+
dashboard_html = format_dashboard(state['selected_feature'])
|
144 |
+
|
145 |
+
return features_html, dashboard_html, state
|
146 |
+
|
147 |
css = """
|
148 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
149 |
|
|
|
166 |
border-radius: 8px;
|
167 |
background-color: #ffffff;
|
168 |
}
|
|
|
|
|
|
|
|
|
169 |
"""
|
170 |
|
171 |
theme = gr.themes.Soft(
|
172 |
primary_hue=gr.themes.colors.Color(
|
173 |
name="blue",
|
174 |
+
c50="#eef1ff", c100="#e0e5ff", c200="#c3cbff",
|
175 |
+
c300="#a5b2ff", c400="#8798ff", c500="#6a7eff",
|
176 |
+
c600="#3452db", c700="#2a41af", c800="#1f3183",
|
177 |
+
c900="#152156", c950="#0a102b",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
)
|
179 |
)
|
180 |
|
181 |
def create_interface():
|
182 |
+
# Initialize state
|
183 |
+
state = gr.State({
|
184 |
+
'features_data': None,
|
185 |
+
'expanded_tokens': [],
|
186 |
+
'selected_feature': None
|
187 |
+
})
|
188 |
+
|
189 |
with gr.Blocks(theme=theme, css=css) as interface:
|
190 |
gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
|
191 |
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
|
|
204 |
)
|
205 |
|
206 |
with gr.Column(scale=2):
|
207 |
+
features_html = gr.HTML()
|
208 |
+
dashboard_html = gr.HTML()
|
209 |
|
210 |
+
# Event handlers
|
211 |
analyze_btn.click(
|
212 |
fn=analyze_features,
|
213 |
+
inputs=[input_text, state],
|
214 |
+
outputs=[features_html, dashboard_html, state]
|
215 |
)
|
216 |
+
|
217 |
+
# Custom JavaScript function for token expansion
|
218 |
+
interface.load(None, None, None, _js="""
|
219 |
+
function toggle_expansion(token) {
|
220 |
+
// Function will be called from HTML onclick
|
221 |
+
}
|
222 |
+
""")
|
223 |
|
224 |
return interface
|
225 |
|