Update app.py
app.py
CHANGED
@@ -1,75 +1,102 @@
+import gradio as gr
 import requests
 import json
-from typing import Dict, List
-import numpy as np
+from typing import Dict, List, Tuple

-
-"
-
-
-
-
-
-
-
-
+BRAND_EXAMPLES = [
+    "Nike - Just Do It. The power of determination.",
+    "Apple - Think Different. Innovation redefined.",
+    "McDonald's - I'm Lovin' It. Creating joy.",
+    "BMW - The Ultimate Driving Machine.",
+    "L'Oréal - Because You're Worth It."
+]
+
+def get_top_features(text: str, k: int = 5) -> Dict:
+    url = "https://www.neuronpedia.org/api/search-with-topk"
+    payload = {
+        "modelId": "gemma-2-2b",
+        "layer": "0-gemmascope-mlp-16k",
+        "sourceSet": "gemma-scope",
+        "text": text,
+        "k": k,
+        "maxDensity": 0.01,
+        "ignoreBos": True
    }

    response = requests.post(
-        url,
-        headers={"Content-Type": "application/json"},
-        json=
+        url,
+        headers={"Content-Type": "application/json"},
+        json=payload
    )
-    return response.json()
-
-def calculate_density(values: List[float], threshold: float = 0.5) -> float:
-    """Calculate activation density (% of tokens with activation > threshold)"""
-    return sum(1 for v in values if v > threshold) / len(values)
+    return response.json() if response.status_code == 200 else None

-def
-
-
+def format_output(data: Dict) -> Tuple[str, str, str]:
+    if not data:
+        return "Error analyzing text", "", ""
+
+    output = "# Neural Feature Analysis\n\n"

-    #
-
-
-
-
-    # Process features in batches
-    for start_idx in range(0, 16384, batch_size):
-        for feature_id in range(start_idx, min(start_idx + batch_size, 16384)):
-            result = get_activation_values(text, feature_id)
-            values = result.get('values', [])
-
-            # Calculate density and skip if too high
-            density = calculate_density(values)
-            if density > max_density:
-                continue
+    # Format token-feature analysis
+    for result in data['results']:
+        token = result['token']
+        if token == '<bos>':  # Skip BOS token
+            continue

-
-
-
-
-
-
-
-            })
+        features = result['top_features']
+        if features:
+            output += f"\n## Token: '{token}'\n"
+            for feat in features:
+                feat_index = feat['feature_index']
+                activation = feat['activation_value']
+                output += f"- **Feature {feat_index}**: activation = {activation:.2f}\n"

-    #
-
-
-        token_features[token] = token_features[token][:num_features]
+    # Get highest activation feature for dashboard
+    max_activation = 0
+    max_feature = None

-
+    for result in data['results']:
+        for feature in result['top_features']:
+            if feature['activation_value'] > max_activation:
+                max_activation = feature['activation_value']
+                max_feature = feature['feature_index']
+
+    if max_feature:
+        dashboard_url = f"https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-16k/{max_feature}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
+        iframe = f'<iframe src="{dashboard_url}" width="100%" height="600px" frameborder="0" style="border:1px solid #eee;border-radius:8px;"></iframe>'
+        feature_label = f"Feature {max_feature} Dashboard (Highest Activation: {max_activation:.2f})"
+    else:
+        iframe = ""
+        feature_label = "No significant features found"
+
+    return output, iframe, feature_label

-
-
-
+def create_interface():
+    with gr.Blocks() as interface:
+        gr.Markdown("# Neural Feature Analyzer")
+        gr.Markdown("Analyze text using Gemma's interpretable neural features\n\nShows top 5 most activated features for each token with density < 1%")
+
+        with gr.Row():
+            with gr.Column():
+                input_text = gr.Textbox(
+                    lines=5,
+                    placeholder="Enter text to analyze...",
+                    label="Input Text"
+                )
+                analyze_btn = gr.Button("Analyze Neural Features", variant="primary")
+                gr.Examples(BRAND_EXAMPLES, inputs=input_text)
+
+            with gr.Column():
+                output_text = gr.Markdown()
+                feature_label = gr.Text(show_label=False)
+                dashboard = gr.HTML()
+
+        analyze_btn.click(
+            fn=lambda text: format_output(get_top_features(text)),
+            inputs=input_text,
+            outputs=[output_text, dashboard, feature_label]
+        )
+
+    return interface

-
-
-    for token, features in token_features.items():
-        if features:  # Only show tokens with active features
-            print(f"\nToken: {token}")
-            for feat in features:
-                print(f"  Feature {feat['feature_id']}: activation={feat['activation']:.3f}, density={feat['density']:.3%}")
+if __name__ == "__main__":
+    create_interface().launch()
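As a quick sanity check, the updated analysis path can be driven outside the Gradio UI. This is a minimal sketch, assuming the new app.py is importable as `app` and that the Neuronpedia endpoint used in get_top_features is reachable; the example text is taken from BRAND_EXAMPLES.

# Minimal sketch: exercise get_top_features/format_output without launching Gradio.
# Assumes this file is importable as `app` and network access to neuronpedia.org.
from app import get_top_features, format_output

data = get_top_features("Nike - Just Do It. The power of determination.", k=5)
if data is None:
    print("Request failed (non-200 response from Neuronpedia)")
else:
    markdown, iframe_html, label = format_output(data)
    print(label)     # e.g. "Feature <id> Dashboard (Highest Activation: ...)"
    print(markdown)  # per-token feature breakdown that the Space renders as Markdown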