cyberandy committed on
Commit
44c881e
·
verified ·
1 Parent(s): 5ee5132

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -61
app.py CHANGED
@@ -1,75 +1,102 @@
 
1
  import requests
2
  import json
3
- from typing import Dict, List
4
- import numpy as np
5
 
6
- def get_activation_values(text: str, feature_id: int) -> Dict:
7
- """Get activation values for a specific feature"""
8
- url = "https://www.neuronpedia.org/api/activation/new"
9
- data = {
10
- "feature": {
11
- "modelId": "gemma-2-2b",
12
- "layer": "0-gemmascope-mlp-16k",
13
- "index": str(feature_id)
14
- },
15
- "customText": text
 
 
 
 
 
 
 
 
16
  }
17
 
18
  response = requests.post(
19
- url,
20
- headers={"Content-Type": "application/json"},
21
- json=data
22
  )
23
- return response.json()
24
-
25
- def calculate_density(values: List[float], threshold: float = 0.5) -> float:
26
- """Calculate activation density (% of tokens with activation > threshold)"""
27
- return sum(1 for v in values if v > threshold) / len(values)
28
 
29
- def find_top_features_per_token(text: str, num_features: int = 5,
30
- max_density: float = 0.01, batch_size: int = 100) -> Dict:
31
- """Find top features for each token with density filtering"""
 
 
32
 
33
- # First get initial feature activations to get tokens
34
- sample_activation = get_activation_values(text, 0)
35
- tokens = sample_activation['tokens']
36
- token_features = {token: [] for token in tokens}
37
-
38
- # Process features in batches
39
- for start_idx in range(0, 16384, batch_size):
40
- for feature_id in range(start_idx, min(start_idx + batch_size, 16384)):
41
- result = get_activation_values(text, feature_id)
42
- values = result.get('values', [])
43
-
44
- # Calculate density and skip if too high
45
- density = calculate_density(values)
46
- if density > max_density:
47
- continue
48
 
49
- # Add feature to each token's list if activated
50
- for token_idx, (token, value) in enumerate(zip(tokens, values)):
51
- if value > 0.5: # Activation threshold
52
- token_features[token].append({
53
- 'feature_id': feature_id,
54
- 'activation': value,
55
- 'density': density
56
- })
57
 
58
- # Sort features for each token and keep top N
59
- for token in token_features:
60
- token_features[token].sort(key=lambda x: x['activation'], reverse=True)
61
- token_features[token] = token_features[token][:num_features]
62
 
63
- return token_features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Test the function
66
- text = "Nike - Just Do It"
67
- token_features = find_top_features_per_token(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Print results
70
- print(f"Text: {text}\n")
71
- for token, features in token_features.items():
72
- if features: # Only show tokens with active features
73
- print(f"\nToken: {token}")
74
- for feat in features:
75
- print(f" Feature {feat['feature_id']}: activation={feat['activation']:.3f}, density={feat['density']:.3%}")
 
1
import json
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
 
5
 
6
# Sample brand taglines offered as one-click examples in the UI.
BRAND_EXAMPLES = [
    "Nike - Just Do It. The power of determination.",
    "Apple - Think Different. Innovation redefined.",
    "McDonald's - I'm Lovin' It. Creating joy.",
    "BMW - The Ultimate Driving Machine.",
    "L'Oréal - Because You're Worth It.",
]
13
+
14
def get_top_features(text: str, k: int = 5) -> Optional[Dict]:
    """Query Neuronpedia's search-with-topk endpoint for the top-k SAE
    features activated by each token of *text*.

    Args:
        text: Input text to analyze.
        k: Number of top features to return per token.

    Returns:
        The parsed JSON response dict on HTTP 200, or None on any
        network error, timeout, or non-200 status (callers treat
        None as "analysis failed").
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "layer": "0-gemmascope-mlp-16k",
        "sourceSet": "gemma-scope",
        "text": text,
        "k": k,
        # Only surface rare features (active on <1% of tokens).
        "maxDensity": 0.01,
        "ignoreBos": True
    }

    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            # Without a timeout a stalled server hangs the UI callback forever.
            timeout=30,
        )
    except requests.RequestException:
        # Network failure / timeout: map to the same None the
        # non-200 path already produces, so the caller's error
        # handling covers both.
        return None
    return response.json() if response.status_code == 200 else None
 
 
 
 
32
 
33
def format_output(data: Dict) -> Tuple[str, str, str]:
    """Render a Neuronpedia response as UI-ready pieces.

    Args:
        data: Parsed JSON from get_top_features, or None on failure.
              Expected to carry a 'results' list of per-token dicts
              with 'token' and 'top_features' entries.

    Returns:
        (markdown report, dashboard iframe HTML, feature label).
        On missing/invalid data: ("Error analyzing text", "", "").
    """
    # Guard both a None response and a response without 'results',
    # so a malformed API payload degrades to the error message
    # instead of raising KeyError.
    if not data or 'results' not in data:
        return "Error analyzing text", "", ""

    output = "# Neural Feature Analysis\n\n"

    # Per-token feature report, skipping the BOS token.
    for result in data['results']:
        token = result['token']
        if token == '<bos>':
            continue

        features = result['top_features']
        if features:
            output += f"\n## Token: '{token}'\n"
            for feat in features:
                feat_index = feat['feature_index']
                activation = feat['activation_value']
                output += f"- **Feature {feat_index}**: activation = {activation:.2f}\n"

    # Find the single highest-activation feature to embed its dashboard.
    max_activation = 0
    max_feature = None

    for result in data['results']:
        for feature in result['top_features']:
            if feature['activation_value'] > max_activation:
                max_activation = feature['activation_value']
                max_feature = feature['feature_index']

    # Compare against None, not truthiness: feature index 0 is a
    # valid feature but is falsy, so `if max_feature:` would wrongly
    # report "No significant features found".
    if max_feature is not None:
        dashboard_url = f"https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-16k/{max_feature}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
        iframe = f'<iframe src="{dashboard_url}" width="100%" height="600px" frameborder="0" style="border:1px solid #eee;border-radius:8px;"></iframe>'
        feature_label = f"Feature {max_feature} Dashboard (Highest Activation: {max_activation:.2f})"
    else:
        iframe = ""
        feature_label = "No significant features found"

    return output, iframe, feature_label
72
 
73
def create_interface():
    """Build and return the Gradio Blocks UI for the feature analyzer."""

    def _analyze(text):
        # Fetch features from Neuronpedia, then render them for the UI.
        return format_output(get_top_features(text))

    with gr.Blocks() as demo:
        gr.Markdown("# Neural Feature Analyzer")
        gr.Markdown("Analyze text using Gemma's interpretable neural features\n\nShows top 5 most activated features for each token with density < 1%")

        with gr.Row():
            # Left column: input controls and examples.
            with gr.Column():
                text_box = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text"
                )
                run_button = gr.Button("Analyze Neural Features", variant="primary")
                gr.Examples(BRAND_EXAMPLES, inputs=text_box)

            # Right column: analysis report, label, and embedded dashboard.
            with gr.Column():
                report_md = gr.Markdown()
                label_box = gr.Text(show_label=False)
                dashboard_html = gr.HTML()

        run_button.click(
            fn=_analyze,
            inputs=text_box,
            outputs=[report_md, dashboard_html, label_box]
        )

    return demo
100
 
101
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    app = create_interface()
    app.launch()