brand-llms / app.py
cyberandy's picture
Update app.py
321a1b2 verified
raw
history blame
7.73 kB
import html
import json
import logging
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
def get_features(text: str, timeout: float = 30.0) -> Optional[Dict]:
    """Get neural features for *text* from the Neuronpedia API.

    POSTs the text to the ``search-with-topk`` endpoint for model
    ``gemma-2-2b``, layer ``20-gemmascope-res-16k``.

    Args:
        text: Input text to analyze.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON response, or None if the request failed, the server
        returned an HTTP error status, or the body was not valid JSON.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            # Without a timeout a stalled server would block this handler forever.
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers invalid JSON bodies. Callers treat None as
        # "analysis failed"; log the cause so it is not lost.
        logging.getLogger(__name__).warning("Neuronpedia request failed: %s", err)
        return None
def format_features(
    features_data: Dict,
    expanded_tokens: List[str],
    selected_feature: Optional[Dict],
    max_collapsed: int = 3,
) -> str:
    """Format per-token features as HTML, honoring each token's expanded state.

    Args:
        features_data: API response. Expected to hold a 'results' list whose
            items carry 'token' and 'top_features'; each feature carries
            'feature_index' and 'activation_value'.
        expanded_tokens: Tokens whose full feature list should be shown.
        selected_feature: The selected feature ({'feature_id': ...}) or None;
            a card whose feature_index matches is highlighted.
        max_collapsed: Number of features shown for a token that is not
            expanded (negative values are treated as 0).

    Returns:
        The HTML string, or "" when features_data is empty or has no 'results'.
    """
    if not features_data or 'results' not in features_data:
        return ""

    limit = max(max_collapsed, 0)
    selected_id = selected_feature.get('feature_id') if selected_feature else None

    output = ['<div class="p-6">']
    for result in features_data['results']:
        token = result['token']
        if token == '<bos>':
            continue
        features = result['top_features']
        is_expanded = token in expanded_tokens
        shown = features if is_expanded else features[:limit]

        # Tokens come straight from the tokenizer and may contain <, &, quotes.
        output.append(
            f'<div class="mb-8"><h2 class="text-xl font-bold mb-4">'
            f'Token: {html.escape(token)}</h2>'
        )

        for feature in shown:
            feature_id = feature['feature_index']
            activation = feature['activation_value']
            is_selected = selected_id is not None and selected_id == feature_id
            selected_class = "border-blue-500 border-2" if is_selected else ""
            output.append(
                f'<div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500 {selected_class}">'
                '<div class="flex justify-between items-center"><div>'
                f'<span class="font-semibold">Feature {html.escape(str(feature_id))}</span>'
                f'<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>'
                '</div></div></div>'
            )

        # Show more/less button only when some features can be hidden.
        if len(features) > limit:
            action = "less" if is_expanded else f"{len(features) - limit} more"
            # json.dumps yields a valid JS string literal; html.escape then makes
            # it safe inside the double-quoted onclick attribute.
            js_token = html.escape(json.dumps(token), quote=True)
            # NOTE(review): no JavaScript `gradio` function is defined in this
            # file, so this handler currently has no effect.
            output.append(
                '<div class="text-center mb-4">'
                '<button class="text-blue-600 hover:text-blue-800 text-sm" '
                f"onclick=\"gradio('toggle_expansion', {js_token})\">"
                f'Show {action} features'
                '</button></div>'
            )
        output.append('</div>')
    output.append('</div>')
    return "\n".join(output)
def format_dashboard(feature: Dict) -> str:
    """Build the dashboard panel that embeds Neuronpedia's page for one feature.

    Args:
        feature: Mapping with 'feature_id' and 'activation' keys, or a falsy
            value when no feature is selected.

    Returns:
        The dashboard HTML, or "" when *feature* is falsy.
    """
    if feature:
        fid = feature['feature_id']
        act = feature['activation']
        lines = [
            '<div class="dashboard-container p-4">',
            '<h3 class="text-lg font-semibold mb-4">',
            f'Feature {fid} Dashboard (Activation: {act:.2f})',
            '</h3>',
            '<iframe',
            f'src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{fid}'
            '?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"',
            'width="100%"',
            'height="600"',
            'frameborder="0"',
            'class="rounded-lg"',
            '></iframe>',
            '</div>',
        ]
        # Leading and trailing newline mirror the triple-quoted layout.
        return "\n" + "\n".join(lines) + "\n"
    return ""
def _first_feature(results: List[Dict]) -> Optional[Dict]:
    """Return the top feature of the first non-<bos> token as a selection dict.

    Returns None when no such token has any features.
    """
    for result in results:
        if result.get('token') != '<bos>' and result.get('top_features'):
            top = result['top_features'][0]
            return {
                'feature_id': top['feature_index'],
                'activation': top['activation_value'],
            }
    return None


def analyze_features(text: str, state: Dict) -> Tuple[str, str, Dict]:
    """Analyze *text* and render the feature list and feature dashboard.

    Args:
        text: User input from the textbox.
        state: Per-session dict holding 'features_data', 'expanded_tokens'
            and 'selected_feature'; updated in place and returned.

    Returns:
        (features_html, dashboard_html, state). Blank input yields two empty
        strings; an API failure yields an error message and an empty dashboard.
    """
    if not text or not text.strip():
        return "", "", state

    features_data = get_features(text)
    if not features_data:
        return "Error analyzing text", "", state

    state['features_data'] = features_data
    if not state.get('expanded_tokens'):
        state['expanded_tokens'] = []

    # Each analysis selects from *this* text, so the dashboard never shows a
    # feature (or activation) left over from a previously analyzed text.
    state['selected_feature'] = _first_feature(features_data.get('results', []))

    features_html = format_features(
        features_data, state['expanded_tokens'], state['selected_feature']
    )
    dashboard_html = format_dashboard(state['selected_feature'])
    return features_html, dashboard_html, state
def toggle_expansion(token: str, state: Dict) -> Tuple[str, str, Dict]:
    """Flip whether *token* shows its full feature list, then re-render.

    Args:
        token: The token whose expanded/collapsed state should flip.
        state: Session dict with 'features_data', 'expanded_tokens' and
            'selected_feature'; 'expanded_tokens' is mutated in place.

    Returns:
        (features_html, dashboard_html, state) after the toggle.
    """
    expanded = state['expanded_tokens']
    if token not in expanded:
        expanded.append(token)
    else:
        expanded.remove(token)

    selected = state['selected_feature']
    return (
        format_features(state['features_data'], expanded, selected),
        format_dashboard(selected),
        state,
    )
# Custom stylesheet passed to gr.Blocks(css=...) in create_interface.
# It loads the Open Sans web font, applies it to the page body, and styles the
# .feature-card elements (emitted by format_features) and the
# .dashboard-container element (emitted by format_dashboard).
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body {
font-family: 'Open Sans', sans-serif !important;
}
.feature-card {
border: 1px solid #e0e5ff;
background-color: #ffffff;
transition: all 0.2s ease;
}
.feature-card:hover {
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
}
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
"""
# Soft theme whose primary hue is a custom blue palette, ordered from the
# lightest shade (c50) to the darkest (c950).
theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.Color(
        c50="#eef1ff",
        c100="#e0e5ff",
        c200="#c3cbff",
        c300="#a5b2ff",
        c400="#8798ff",
        c500="#6a7eff",
        c600="#3452db",
        c700="#2a41af",
        c800="#1f3183",
        c900="#152156",
        c950="#0a102b",
        name="blue",
    ),
)
def create_interface():
    """Build the Gradio Blocks app for the neural feature analyzer.

    The left column holds the input textbox, the analyze button and the
    examples. The right column holds the feature list and the feature
    dashboard. The button is wired to analyze_features.

    Returns:
        The constructed (not yet launched) gr.Blocks interface.
    """
    with gr.Blocks(theme=theme, css=css) as interface:
        # Per-session UI state. Like every component, gr.State must be created
        # inside the Blocks context so it is registered with the app. Per the
        # gr.State docs, the initial value is deep-copied for each session.
        state = gr.State({
            'features_data': None,
            'expanded_tokens': [],
            'selected_feature': None,
        })

        gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
        gr.Markdown(
            "*Analyze text using Gemma's interpretable neural features*",
            elem_classes="text-gray-600 mb-6",
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                analyze_btn = gr.Button("Analyze Features", variant="primary")
                gr.Examples(
                    examples=["WordLift", "Think Different", "Just Do It"],
                    inputs=input_text,
                )
            with gr.Column(scale=2):
                features_html = gr.HTML()
                dashboard_html = gr.HTML()

        # Event handlers
        analyze_btn.click(
            fn=analyze_features,
            inputs=[input_text, state],
            outputs=[features_html, dashboard_html, state],
        )

        # NOTE(review): the "Show N more/less features" buttons emitted by
        # format_features call a JavaScript `gradio(...)` function that nothing
        # defines, and toggle_expansion is not bound to any event. Those
        # buttons are therefore inert until a proper trigger (e.g. a hidden
        # textbox + event) is wired up.
    return interface
if __name__ == "__main__":
    # Build the UI, then serve it with a public share link.
    demo = create_interface()
    demo.launch(share=True)