cyberandy commited on
Commit
30ede40
·
verified ·
1 Parent(s): 383d1f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -42
app.py CHANGED
@@ -4,7 +4,7 @@ import requests
4
  def get_features(text: str):
5
  url = "https://www.neuronpedia.org/api/search-with-topk"
6
  payload = {
7
- "modelId": "gemma-2-2b",
8
  "text": text,
9
  "layer": "20-gemmascope-res-16k"
10
  }
@@ -29,15 +29,18 @@ def create_dashboard(feature_id: int) -> str:
29
  </div>
30
  """
31
 
 
 
 
32
  def analyze_text(text: str):
33
  if not text:
34
- return gr.update(visible=False), ""
35
-
36
  features_data = get_features(text)
37
  if not features_data:
38
- return gr.update(visible=False), ""
39
 
40
- html = "<div class='features-list'>"
41
  first_feature_id = None
42
 
43
  for result in features_data['results']:
@@ -45,24 +48,22 @@ def analyze_text(text: str):
45
  continue
46
 
47
  token = result['token']
48
- html += f"<h3>{token}</h3>"
49
 
50
  for feature in result['top_features'][:3]:
51
  feature_id = feature['feature_index']
52
  if first_feature_id is None:
53
  first_feature_id = feature_id
54
-
55
- html += f"""
56
- <button onclick='document.dispatchEvent(new CustomEvent("select_feature",
57
- {{detail: {{feature_id: {feature_id}}}}}))' class='feature-button'>
58
- Feature {feature_id} (Activation: {feature['activation_value']:.2f})
59
- </button>
60
- """
61
-
62
- html += "</div>"
63
- initial_dashboard = create_dashboard(first_feature_id) if first_feature_id else ""
64
 
65
- return html, initial_dashboard
66
 
67
  css = """
68
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
@@ -75,19 +76,21 @@ body { font-family: 'Open Sans', sans-serif !important; }
75
  background-color: #ffffff;
76
  }
77
 
78
- .features-list h3 {
79
- margin-top: 1rem;
80
  font-weight: 600;
 
 
81
  }
82
 
83
  .feature-button {
84
- display: block;
85
- margin: 0.5rem 0;
86
  padding: 0.5rem 1rem;
87
  background-color: #f3f4f6;
88
  border: 1px solid #e5e7eb;
89
  border-radius: 0.375rem;
90
- cursor: pointer;
91
  }
92
 
93
  .feature-button:hover {
@@ -95,23 +98,12 @@ body { font-family: 'Open Sans', sans-serif !important; }
95
  }
96
  """
97
 
98
- theme = gr.themes.Soft(
99
- primary_hue=gr.themes.colors.Color(
100
- name="blue",
101
- c50="#eef1ff", c100="#e0e5ff", c200="#c3cbff",
102
- c300="#a5b2ff", c400="#8798ff", c500="#6a7eff",
103
- c600="#3452db", c700="#2a41af", c800="#1f3183",
104
- c900="#152156", c950="#0a102b",
105
- )
106
- )
107
-
108
- def update_dashboard(feature_id: int):
109
- return create_dashboard(feature_id)
110
-
111
- with gr.Blocks(theme=theme, css=css) as demo:
112
  gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
113
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
114
 
 
 
115
  with gr.Row():
116
  with gr.Column(scale=1):
117
  input_text = gr.Textbox(
@@ -126,17 +118,32 @@ with gr.Blocks(theme=theme, css=css) as demo:
126
  )
127
 
128
  with gr.Column(scale=2):
129
- features_html = gr.HTML()
130
- dashboard = gr.HTML()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- # Handle feature selection via JavaScript
133
- dashboard.change(fn=update_dashboard, inputs=gr.Textbox(visible=False), outputs=dashboard)
134
 
135
  analyze_btn.click(
136
  fn=analyze_text,
137
  inputs=[input_text],
138
- outputs=[features_html, dashboard]
139
  )
140
 
141
  if __name__ == "__main__":
142
- demo.launch(share=True)
 
4
  def get_features(text: str):
5
  url = "https://www.neuronpedia.org/api/search-with-topk"
6
  payload = {
7
+ "modelId": "gemma-2-2b",
8
  "text": text,
9
  "layer": "20-gemmascope-res-16k"
10
  }
 
29
  </div>
30
  """
31
 
32
+ def handle_feature_click(feature_id):
33
+ return create_dashboard(feature_id)
34
+
35
  def analyze_text(text: str):
36
  if not text:
37
+ return [], ""
38
+
39
  features_data = get_features(text)
40
  if not features_data:
41
+ return [], ""
42
 
43
+ features = []
44
  first_feature_id = None
45
 
46
  for result in features_data['results']:
 
48
  continue
49
 
50
  token = result['token']
51
+ token_features = []
52
 
53
  for feature in result['top_features'][:3]:
54
  feature_id = feature['feature_index']
55
  if first_feature_id is None:
56
  first_feature_id = feature_id
57
+
58
+ token_features.append({
59
+ "token": token,
60
+ "id": feature_id,
61
+ "activation": feature['activation_value']
62
+ })
63
+
64
+ features.append({"token": token, "features": token_features})
 
 
65
 
66
+ return features, create_dashboard(first_feature_id) if first_feature_id else ""
67
 
68
  css = """
69
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
 
76
  background-color: #ffffff;
77
  }
78
 
79
+ .token-header {
80
+ font-size: 1.25rem;
81
  font-weight: 600;
82
+ margin-top: 1rem;
83
+ margin-bottom: 0.5rem;
84
  }
85
 
86
  .feature-button {
87
+ display: inline-block;
88
+ margin: 0.25rem;
89
  padding: 0.5rem 1rem;
90
  background-color: #f3f4f6;
91
  border: 1px solid #e5e7eb;
92
  border-radius: 0.375rem;
93
+ font-size: 0.875rem;
94
  }
95
 
96
  .feature-button:hover {
 
98
  }
99
  """
100
 
101
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
103
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
104
 
105
+ features_state = gr.State([])
106
+
107
  with gr.Row():
108
  with gr.Column(scale=1):
109
  input_text = gr.Textbox(
 
118
  )
119
 
120
  with gr.Column(scale=2):
121
+ @gr.render(every=0.5)
122
+ def render_features():
123
+ features = features_state.value
124
+ if not features:
125
+ return
126
+
127
+ for token_group in features:
128
+ gr.Markdown(f"### {token_group['token']}")
129
+ with gr.Row():
130
+ for feature in token_group['features']:
131
+ btn = gr.Button(
132
+ f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
133
+ elem_classes=["feature-button"]
134
+ )
135
+ btn.click(
136
+ fn=lambda fid=feature['id']: handle_feature_click(fid),
137
+ outputs=dashboard
138
+ )
139
 
140
+ dashboard = gr.HTML()
 
141
 
142
  analyze_btn.click(
143
  fn=analyze_text,
144
  inputs=[input_text],
145
+ outputs=[features_state, dashboard]
146
  )
147
 
148
  if __name__ == "__main__":
149
+ demo.launch(share=False)