Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import requests
|
|
4 |
def get_features(text: str):
|
5 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
6 |
payload = {
|
7 |
-
"modelId": "gemma-2-2b",
|
8 |
"text": text,
|
9 |
"layer": "20-gemmascope-res-16k"
|
10 |
}
|
@@ -29,15 +29,18 @@ def create_dashboard(feature_id: int) -> str:
|
|
29 |
</div>
|
30 |
"""
|
31 |
|
|
|
|
|
|
|
32 |
def analyze_text(text: str):
|
33 |
if not text:
|
34 |
-
return
|
35 |
-
|
36 |
features_data = get_features(text)
|
37 |
if not features_data:
|
38 |
-
return
|
39 |
|
40 |
-
|
41 |
first_feature_id = None
|
42 |
|
43 |
for result in features_data['results']:
|
@@ -45,24 +48,22 @@ def analyze_text(text: str):
|
|
45 |
continue
|
46 |
|
47 |
token = result['token']
|
48 |
-
|
49 |
|
50 |
for feature in result['top_features'][:3]:
|
51 |
feature_id = feature['feature_index']
|
52 |
if first_feature_id is None:
|
53 |
first_feature_id = feature_id
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
html += "</div>"
|
63 |
-
initial_dashboard = create_dashboard(first_feature_id) if first_feature_id else ""
|
64 |
|
65 |
-
return
|
66 |
|
67 |
css = """
|
68 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
@@ -75,19 +76,21 @@ body { font-family: 'Open Sans', sans-serif !important; }
|
|
75 |
background-color: #ffffff;
|
76 |
}
|
77 |
|
78 |
-
.
|
79 |
-
|
80 |
font-weight: 600;
|
|
|
|
|
81 |
}
|
82 |
|
83 |
.feature-button {
|
84 |
-
display: block;
|
85 |
-
margin: 0.
|
86 |
padding: 0.5rem 1rem;
|
87 |
background-color: #f3f4f6;
|
88 |
border: 1px solid #e5e7eb;
|
89 |
border-radius: 0.375rem;
|
90 |
-
|
91 |
}
|
92 |
|
93 |
.feature-button:hover {
|
@@ -95,23 +98,12 @@ body { font-family: 'Open Sans', sans-serif !important; }
|
|
95 |
}
|
96 |
"""
|
97 |
|
98 |
-
theme
|
99 |
-
primary_hue=gr.themes.colors.Color(
|
100 |
-
name="blue",
|
101 |
-
c50="#eef1ff", c100="#e0e5ff", c200="#c3cbff",
|
102 |
-
c300="#a5b2ff", c400="#8798ff", c500="#6a7eff",
|
103 |
-
c600="#3452db", c700="#2a41af", c800="#1f3183",
|
104 |
-
c900="#152156", c950="#0a102b",
|
105 |
-
)
|
106 |
-
)
|
107 |
-
|
108 |
-
def update_dashboard(feature_id: int):
|
109 |
-
return create_dashboard(feature_id)
|
110 |
-
|
111 |
-
with gr.Blocks(theme=theme, css=css) as demo:
|
112 |
gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
|
113 |
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
114 |
|
|
|
|
|
115 |
with gr.Row():
|
116 |
with gr.Column(scale=1):
|
117 |
input_text = gr.Textbox(
|
@@ -126,17 +118,32 @@ with gr.Blocks(theme=theme, css=css) as demo:
|
|
126 |
)
|
127 |
|
128 |
with gr.Column(scale=2):
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
dashboard.change(fn=update_dashboard, inputs=gr.Textbox(visible=False), outputs=dashboard)
|
134 |
|
135 |
analyze_btn.click(
|
136 |
fn=analyze_text,
|
137 |
inputs=[input_text],
|
138 |
-
outputs=[
|
139 |
)
|
140 |
|
141 |
if __name__ == "__main__":
|
142 |
-
demo.launch(share=
|
|
|
4 |
def get_features(text: str):
|
5 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
6 |
payload = {
|
7 |
+
"modelId": "gemma-2-2b",
|
8 |
"text": text,
|
9 |
"layer": "20-gemmascope-res-16k"
|
10 |
}
|
|
|
29 |
</div>
|
30 |
"""
|
31 |
|
32 |
+
def handle_feature_click(feature_id):
|
33 |
+
return create_dashboard(feature_id)
|
34 |
+
|
35 |
def analyze_text(text: str):
|
36 |
if not text:
|
37 |
+
return [], ""
|
38 |
+
|
39 |
features_data = get_features(text)
|
40 |
if not features_data:
|
41 |
+
return [], ""
|
42 |
|
43 |
+
features = []
|
44 |
first_feature_id = None
|
45 |
|
46 |
for result in features_data['results']:
|
|
|
48 |
continue
|
49 |
|
50 |
token = result['token']
|
51 |
+
token_features = []
|
52 |
|
53 |
for feature in result['top_features'][:3]:
|
54 |
feature_id = feature['feature_index']
|
55 |
if first_feature_id is None:
|
56 |
first_feature_id = feature_id
|
57 |
+
|
58 |
+
token_features.append({
|
59 |
+
"token": token,
|
60 |
+
"id": feature_id,
|
61 |
+
"activation": feature['activation_value']
|
62 |
+
})
|
63 |
+
|
64 |
+
features.append({"token": token, "features": token_features})
|
|
|
|
|
65 |
|
66 |
+
return features, create_dashboard(first_feature_id) if first_feature_id else ""
|
67 |
|
68 |
css = """
|
69 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
|
|
76 |
background-color: #ffffff;
|
77 |
}
|
78 |
|
79 |
+
.token-header {
|
80 |
+
font-size: 1.25rem;
|
81 |
font-weight: 600;
|
82 |
+
margin-top: 1rem;
|
83 |
+
margin-bottom: 0.5rem;
|
84 |
}
|
85 |
|
86 |
.feature-button {
|
87 |
+
display: inline-block;
|
88 |
+
margin: 0.25rem;
|
89 |
padding: 0.5rem 1rem;
|
90 |
background-color: #f3f4f6;
|
91 |
border: 1px solid #e5e7eb;
|
92 |
border-radius: 0.375rem;
|
93 |
+
font-size: 0.875rem;
|
94 |
}
|
95 |
|
96 |
.feature-button:hover {
|
|
|
98 |
}
|
99 |
"""
|
100 |
|
101 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
|
103 |
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
104 |
|
105 |
+
features_state = gr.State([])
|
106 |
+
|
107 |
with gr.Row():
|
108 |
with gr.Column(scale=1):
|
109 |
input_text = gr.Textbox(
|
|
|
118 |
)
|
119 |
|
120 |
with gr.Column(scale=2):
|
121 |
+
@gr.render(every=0.5)
|
122 |
+
def render_features():
|
123 |
+
features = features_state.value
|
124 |
+
if not features:
|
125 |
+
return
|
126 |
+
|
127 |
+
for token_group in features:
|
128 |
+
gr.Markdown(f"### {token_group['token']}")
|
129 |
+
with gr.Row():
|
130 |
+
for feature in token_group['features']:
|
131 |
+
btn = gr.Button(
|
132 |
+
f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
|
133 |
+
elem_classes=["feature-button"]
|
134 |
+
)
|
135 |
+
btn.click(
|
136 |
+
fn=lambda fid=feature['id']: handle_feature_click(fid),
|
137 |
+
outputs=dashboard
|
138 |
+
)
|
139 |
|
140 |
+
dashboard = gr.HTML()
|
|
|
141 |
|
142 |
analyze_btn.click(
|
143 |
fn=analyze_text,
|
144 |
inputs=[input_text],
|
145 |
+
outputs=[features_state, dashboard]
|
146 |
)
|
147 |
|
148 |
if __name__ == "__main__":
|
149 |
+
demo.launch(share=False)
|