Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,110 @@
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
from typing import Dict, Tuple, List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
# Define custom CSS with Open Sans font and color theme
|
6 |
css = """
|
7 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
8 |
|
@@ -10,14 +112,6 @@ body {
|
|
10 |
font-family: 'Open Sans', sans-serif !important;
|
11 |
}
|
12 |
|
13 |
-
.primary-btn {
|
14 |
-
background-color: #3452db !important;
|
15 |
-
}
|
16 |
-
|
17 |
-
.primary-btn:hover {
|
18 |
-
background-color: #2a41af !important;
|
19 |
-
}
|
20 |
-
|
21 |
.feature-card {
|
22 |
border: 1px solid #e0e5ff;
|
23 |
background-color: #ffffff;
|
@@ -29,25 +123,6 @@ body {
|
|
29 |
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
|
30 |
}
|
31 |
|
32 |
-
.feature-card.selected {
|
33 |
-
border: 2px solid #3452db;
|
34 |
-
background-color: #eef1ff;
|
35 |
-
}
|
36 |
-
|
37 |
-
.show-more-btn {
|
38 |
-
color: #3452db;
|
39 |
-
font-weight: 600;
|
40 |
-
}
|
41 |
-
|
42 |
-
.show-more-btn:hover {
|
43 |
-
color: #2a41af;
|
44 |
-
}
|
45 |
-
|
46 |
-
.token-header {
|
47 |
-
color: #152156;
|
48 |
-
font-weight: 700;
|
49 |
-
}
|
50 |
-
|
51 |
.dashboard-container {
|
52 |
border: 1px solid #e0e5ff;
|
53 |
border-radius: 8px;
|
@@ -55,7 +130,6 @@ body {
|
|
55 |
}
|
56 |
"""
|
57 |
|
58 |
-
# Create custom theme
|
59 |
theme = gr.themes.Soft(
|
60 |
primary_hue=gr.themes.colors.Color(
|
61 |
name="blue",
|
@@ -73,229 +147,96 @@ theme = gr.themes.Soft(
|
|
73 |
)
|
74 |
)
|
75 |
|
76 |
-
def
|
77 |
-
"""
|
78 |
-
|
79 |
-
|
80 |
-
"modelId": "gemma-2-2b",
|
81 |
-
"text": text,
|
82 |
-
"layer": "20-gemmascope-res-16k"
|
83 |
-
}
|
84 |
-
|
85 |
-
try:
|
86 |
-
response = requests.post(
|
87 |
-
url,
|
88 |
-
headers={"Content-Type": "application/json"},
|
89 |
-
json=payload
|
90 |
-
)
|
91 |
-
response.raise_for_status()
|
92 |
-
return response.json()
|
93 |
-
except Exception as e:
|
94 |
-
return None
|
95 |
-
|
96 |
-
def create_feature_html(feature_id: int, activation: float, selected: bool = False) -> str:
|
97 |
-
"""Create HTML for an individual feature card."""
|
98 |
-
selected_class = "selected" if selected else ""
|
99 |
-
return f"""
|
100 |
-
<div class="feature-card {selected_class} p-4 rounded-lg mb-4"
|
101 |
-
data-feature-id="{feature_id}"
|
102 |
-
onclick="selectFeature(this, {feature_id}, {activation})">
|
103 |
-
<div class="flex justify-between items-center">
|
104 |
-
<div>
|
105 |
-
<span class="font-semibold">Feature {feature_id}</span>
|
106 |
-
<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
|
107 |
-
</div>
|
108 |
-
</div>
|
109 |
-
</div>
|
110 |
-
"""
|
111 |
-
|
112 |
-
def create_token_section(token: str, features: List[Dict], initial_count: int = 3) -> str:
|
113 |
-
"""Create HTML for a token section with its features."""
|
114 |
-
features_html = "".join([
|
115 |
-
create_feature_html(f['feature_index'], f['activation_value'])
|
116 |
-
for f in features[:initial_count]
|
117 |
-
])
|
118 |
-
|
119 |
-
show_more = ""
|
120 |
-
if len(features) > initial_count:
|
121 |
-
remaining = len(features) - initial_count
|
122 |
-
hidden_features = "".join([
|
123 |
-
create_feature_html(f['feature_index'], f['activation_value'])
|
124 |
-
for f in features[initial_count:]
|
125 |
-
])
|
126 |
-
show_more = f"""
|
127 |
-
<div class="hidden" id="more-features-{token}">{hidden_features}</div>
|
128 |
-
<button id="toggle-btn-{token}"
|
129 |
-
class="show-more-btn text-sm mt-2"
|
130 |
-
onclick="toggleFeatures('{token}')">
|
131 |
-
Show {remaining} More Features
|
132 |
-
</button>
|
133 |
-
"""
|
134 |
-
|
135 |
-
return f"""
|
136 |
-
<div class="mb-6">
|
137 |
-
<h2 class="token-header text-xl mb-4">Token: {token}</h2>
|
138 |
-
<div id="features-{token}">
|
139 |
-
{features_html}
|
140 |
-
</div>
|
141 |
-
{show_more}
|
142 |
-
</div>
|
143 |
-
"""
|
144 |
-
|
145 |
-
def create_dashboard_html(feature_id: int, activation: float) -> str:
|
146 |
-
"""Create HTML for the feature dashboard."""
|
147 |
-
return f"""
|
148 |
-
<div class="dashboard-container p-4">
|
149 |
-
<h3 class="text-lg font-semibold mb-4 text-gray-900">
|
150 |
-
Feature {feature_id} Dashboard (Activation: {activation:.2f})
|
151 |
-
</h3>
|
152 |
-
<iframe
|
153 |
-
src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
|
154 |
-
width="100%"
|
155 |
-
height="600"
|
156 |
-
frameborder="0"
|
157 |
-
class="rounded-lg"
|
158 |
-
></iframe>
|
159 |
-
</div>
|
160 |
-
"""
|
161 |
-
|
162 |
-
def create_interface_html(data: Dict) -> str:
|
163 |
-
"""Create the complete interface HTML with JavaScript functionality."""
|
164 |
-
js_code = """
|
165 |
-
<script>
|
166 |
-
function updateDashboard(featureId, activation) {
|
167 |
-
const dashboardContainer = document.getElementById('dashboard-container');
|
168 |
-
dashboardContainer.innerHTML = `
|
169 |
-
<div class="dashboard-container p-4">
|
170 |
-
<h3 class="text-lg font-semibold mb-4 text-gray-900">
|
171 |
-
Feature ${featureId} Dashboard (Activation: ${activation.toFixed(2)})
|
172 |
-
</h3>
|
173 |
-
<iframe
|
174 |
-
src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/${featureId}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
|
175 |
-
width="100%"
|
176 |
-
height="600"
|
177 |
-
frameborder="0"
|
178 |
-
class="rounded-lg"
|
179 |
-
></iframe>
|
180 |
-
</div>
|
181 |
-
`;
|
182 |
-
}
|
183 |
-
|
184 |
-
function selectFeature(element, featureId, activation) {
|
185 |
-
// Update selected state visually
|
186 |
-
document.querySelectorAll('.feature-card').forEach(card => {
|
187 |
-
card.classList.remove('selected');
|
188 |
-
});
|
189 |
-
element.classList.add('selected');
|
190 |
-
|
191 |
-
// Update dashboard
|
192 |
-
updateDashboard(featureId, activation);
|
193 |
-
}
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
const allFeatures = featuresContainer.querySelectorAll('.feature-card');
|
209 |
-
Array.from(allFeatures).slice(3).forEach(card => card.remove());
|
210 |
-
moreFeatures.classList.add('hidden');
|
211 |
-
toggleButton.textContent = `Show ${moreFeatures.children.length} More Features`;
|
212 |
-
}
|
213 |
}
|
214 |
-
|
215 |
-
|
|
|
|
|
216 |
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
tokens_html += create_token_section(result['token'], result['top_features'])
|
226 |
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
</div>
|
245 |
-
</div>
|
246 |
-
</div>
|
247 |
-
</div>
|
248 |
-
"""
|
249 |
-
|
250 |
-
def analyze_features(text: str) -> Tuple[str, str, str]:
|
251 |
-
data = get_features(text)
|
252 |
-
if not data:
|
253 |
-
return "Error analyzing text", "", ""
|
254 |
|
255 |
-
|
256 |
-
return
|
257 |
|
258 |
def create_interface():
|
|
|
|
|
259 |
with gr.Blocks(theme=theme, css=css) as interface:
|
260 |
-
gr.Markdown(
|
261 |
-
|
262 |
-
elem_classes="text-2xl font-bold text-gray-900 mb-2"
|
263 |
-
)
|
264 |
-
gr.Markdown(
|
265 |
-
"*Analyze your brand using Gemma's interpretable neural features*",
|
266 |
-
elem_classes="text-gray-600 mb-6"
|
267 |
-
)
|
268 |
|
269 |
with gr.Row():
|
270 |
-
with gr.Column():
|
271 |
input_text = gr.Textbox(
|
272 |
lines=5,
|
273 |
placeholder="Enter text to analyze...",
|
274 |
-
label="Input Text"
|
275 |
-
elem_classes="mb-4"
|
276 |
-
)
|
277 |
-
analyze_btn = gr.Button(
|
278 |
-
"Analyze Features",
|
279 |
-
variant="primary",
|
280 |
-
elem_classes="primary-btn"
|
281 |
)
|
282 |
-
|
283 |
gr.Examples(
|
284 |
examples=["WordLift", "Think Different", "Just Do It"],
|
285 |
inputs=input_text
|
286 |
)
|
287 |
|
288 |
-
with gr.Column():
|
289 |
-
|
290 |
-
feature_label = gr.Text(show_label=False, visible=False)
|
291 |
-
dashboard = gr.HTML(visible=False)
|
292 |
|
|
|
293 |
analyze_btn.click(
|
294 |
fn=analyze_features,
|
295 |
-
inputs=input_text,
|
296 |
-
outputs=[
|
297 |
)
|
298 |
-
|
299 |
return interface
|
300 |
|
301 |
if __name__ == "__main__":
|
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
from typing import Dict, Tuple, List
|
4 |
+
import json
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class Feature:
|
10 |
+
feature_id: int
|
11 |
+
activation: float
|
12 |
+
token: str
|
13 |
+
position: int
|
14 |
+
|
15 |
+
class FeatureState:
|
16 |
+
def __init__(self):
|
17 |
+
self.features_by_token = {}
|
18 |
+
self.expanded_tokens = set()
|
19 |
+
self.selected_feature = None
|
20 |
+
|
21 |
+
def get_features(text: str) -> Dict:
|
22 |
+
"""Get neural features from the API using the exact website parameters."""
|
23 |
+
url = "https://www.neuronpedia.org/api/search-with-topk"
|
24 |
+
payload = {
|
25 |
+
"modelId": "gemma-2-2b",
|
26 |
+
"text": text,
|
27 |
+
"layer": "20-gemmascope-res-16k"
|
28 |
+
}
|
29 |
+
|
30 |
+
try:
|
31 |
+
response = requests.post(
|
32 |
+
url,
|
33 |
+
headers={"Content-Type": "application/json"},
|
34 |
+
json=payload
|
35 |
+
)
|
36 |
+
response.raise_for_status()
|
37 |
+
return response.json()
|
38 |
+
except Exception as e:
|
39 |
+
return None
|
40 |
+
|
41 |
+
def format_feature_list(features: List[Feature], token: str, expanded: bool = False) -> str:
|
42 |
+
"""Format features as HTML list."""
|
43 |
+
display_features = features if expanded else features[:3]
|
44 |
+
features_html = ""
|
45 |
+
|
46 |
+
for feature in display_features:
|
47 |
+
features_html += f"""
|
48 |
+
<div class="feature-card p-4 rounded-lg mb-4 cursor-pointer hover:border-blue-500"
|
49 |
+
data-feature-id="{feature.feature_id}">
|
50 |
+
<div class="flex justify-between items-center">
|
51 |
+
<div>
|
52 |
+
<span class="font-semibold">Feature {feature.feature_id}</span>
|
53 |
+
<span class="ml-2 text-gray-600">(Activation: {feature.activation:.2f})</span>
|
54 |
+
</div>
|
55 |
+
</div>
|
56 |
+
</div>
|
57 |
+
"""
|
58 |
+
|
59 |
+
if not expanded and len(features) > 3:
|
60 |
+
remaining = len(features) - 3
|
61 |
+
features_html += f"""
|
62 |
+
<div class="text-center">
|
63 |
+
<span class="text-blue-500 text-sm">{remaining} more features available</span>
|
64 |
+
</div>
|
65 |
+
"""
|
66 |
+
|
67 |
+
return features_html
|
68 |
+
|
69 |
+
def format_dashboard(feature: Feature) -> str:
|
70 |
+
"""Format the dashboard HTML for a selected feature."""
|
71 |
+
if not feature:
|
72 |
+
return ""
|
73 |
+
|
74 |
+
return f"""
|
75 |
+
<div class="dashboard-container p-4">
|
76 |
+
<h3 class="text-lg font-semibold mb-4 text-gray-900">
|
77 |
+
Feature {feature.feature_id} Dashboard (Activation: {feature.activation:.2f})
|
78 |
+
</h3>
|
79 |
+
<iframe
|
80 |
+
src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature.feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
|
81 |
+
width="100%"
|
82 |
+
height="600"
|
83 |
+
frameborder="0"
|
84 |
+
class="rounded-lg"
|
85 |
+
></iframe>
|
86 |
+
</div>
|
87 |
+
"""
|
88 |
+
|
89 |
+
def process_features(data: Dict) -> Dict[str, List[Feature]]:
|
90 |
+
"""Process API response into features grouped by token."""
|
91 |
+
features_by_token = {}
|
92 |
+
for result in data.get('results', []):
|
93 |
+
if result['token'] == '<bos>':
|
94 |
+
continue
|
95 |
+
|
96 |
+
token = result['token']
|
97 |
+
features = []
|
98 |
+
for idx, feature in enumerate(result.get('top_features', [])):
|
99 |
+
features.append(Feature(
|
100 |
+
feature_id=feature['feature_index'],
|
101 |
+
activation=feature['activation_value'],
|
102 |
+
token=token,
|
103 |
+
position=idx
|
104 |
+
))
|
105 |
+
features_by_token[token] = features
|
106 |
+
return features_by_token
|
107 |
|
|
|
108 |
css = """
|
109 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
110 |
|
|
|
112 |
font-family: 'Open Sans', sans-serif !important;
|
113 |
}
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
.feature-card {
|
116 |
border: 1px solid #e0e5ff;
|
117 |
background-color: #ffffff;
|
|
|
123 |
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
|
124 |
}
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
.dashboard-container {
|
127 |
border: 1px solid #e0e5ff;
|
128 |
border-radius: 8px;
|
|
|
130 |
}
|
131 |
"""
|
132 |
|
|
|
133 |
theme = gr.themes.Soft(
|
134 |
primary_hue=gr.themes.colors.Color(
|
135 |
name="blue",
|
|
|
147 |
)
|
148 |
)
|
149 |
|
150 |
+
def analyze_features(text: str, state: Optional[Dict] = None) -> Tuple[str, Dict]:
|
151 |
+
"""Main analysis function that processes text and returns formatted output."""
|
152 |
+
if not text:
|
153 |
+
return "", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
+
data = get_features(text)
|
156 |
+
if not data:
|
157 |
+
return "Error analyzing text", None
|
158 |
+
|
159 |
+
# Process features and build state
|
160 |
+
features_by_token = process_features(data)
|
161 |
+
|
162 |
+
# Initialize state if needed
|
163 |
+
if not state:
|
164 |
+
state = {
|
165 |
+
'features_by_token': features_by_token,
|
166 |
+
'expanded_tokens': set(),
|
167 |
+
'selected_feature': None
|
|
|
|
|
|
|
|
|
|
|
168 |
}
|
169 |
+
# Select first feature as default
|
170 |
+
first_token = next(iter(features_by_token))
|
171 |
+
if features_by_token[first_token]:
|
172 |
+
state['selected_feature'] = features_by_token[first_token][0]
|
173 |
|
174 |
+
# Build output HTML
|
175 |
+
output = []
|
176 |
+
for token, features in features_by_token.items():
|
177 |
+
expanded = token in state['expanded_tokens']
|
178 |
+
token_html = f"<h2 class='text-xl font-bold mb-4'>Token: {token}</h2>"
|
179 |
+
features_html = format_feature_list(features, token, expanded)
|
180 |
+
|
181 |
+
output.append(f"<div class='mb-6'>{token_html}{features_html}</div>")
|
182 |
|
183 |
+
# Add dashboard if a feature is selected
|
184 |
+
if state['selected_feature']:
|
185 |
+
output.append(format_dashboard(state['selected_feature']))
|
|
|
|
|
186 |
|
187 |
+
return "\n".join(output), state
|
188 |
+
|
189 |
+
def toggle_expansion(token: str, state: Dict) -> Tuple[str, Dict]:
|
190 |
+
"""Toggle expansion state for a token's features."""
|
191 |
+
if token in state['expanded_tokens']:
|
192 |
+
state['expanded_tokens'].remove(token)
|
193 |
+
else:
|
194 |
+
state['expanded_tokens'].add(token)
|
195 |
|
196 |
+
output_html, state = analyze_features(None, state)
|
197 |
+
return output_html, state
|
198 |
+
|
199 |
+
def select_feature(feature_id: int, state: Dict) -> Tuple[str, Dict]:
|
200 |
+
"""Select a feature and update the dashboard."""
|
201 |
+
for features in state['features_by_token'].values():
|
202 |
+
for feature in features:
|
203 |
+
if feature.feature_id == feature_id:
|
204 |
+
state['selected_feature'] = feature
|
205 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
+
output_html, state = analyze_features(None, state)
|
208 |
+
return output_html, state
|
209 |
|
210 |
def create_interface():
|
211 |
+
state = gr.State({})
|
212 |
+
|
213 |
with gr.Blocks(theme=theme, css=css) as interface:
|
214 |
+
gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
|
215 |
+
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
with gr.Row():
|
218 |
+
with gr.Column(scale=1):
|
219 |
input_text = gr.Textbox(
|
220 |
lines=5,
|
221 |
placeholder="Enter text to analyze...",
|
222 |
+
label="Input Text"
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
)
|
224 |
+
analyze_btn = gr.Button("Analyze Features", variant="primary")
|
225 |
gr.Examples(
|
226 |
examples=["WordLift", "Think Different", "Just Do It"],
|
227 |
inputs=input_text
|
228 |
)
|
229 |
|
230 |
+
with gr.Column(scale=2):
|
231 |
+
output = gr.HTML()
|
|
|
|
|
232 |
|
233 |
+
# Event handlers
|
234 |
analyze_btn.click(
|
235 |
fn=analyze_features,
|
236 |
+
inputs=[input_text, state],
|
237 |
+
outputs=[output, state]
|
238 |
)
|
239 |
+
|
240 |
return interface
|
241 |
|
242 |
if __name__ == "__main__":
|