Spaces:
Running
Running
import requests | |
import gradio as gr | |
from enum import Enum | |
class Model(Enum):
    """Models the analyzer can query; each value is the Neuronpedia ``modelId``."""

    # Chosen when the UI model selection is "Gemini".
    GEMMA = "gemma-2-2b"
    # Chosen for any other UI model selection (e.g. "OpenAI").
    GPT2 = "gpt2-small"
# Per-model Neuronpedia source identifier. It is sent as the "layer" field of
# the search request and, lowercased, used as a path segment of the embedded
# dashboard URL.
MODEL_CONFIGS = {
    Model.GEMMA: "20-gemmascope-res-16k",
    Model.GPT2: "9-res-jb",
}
def get_features(text: str, model: Model, timeout: float = 30.0):
    """Query Neuronpedia's search-with-topk endpoint for per-token top features.

    Args:
        text: Text to analyze.
        model: Model to query; ``model.value`` is sent as ``modelId`` and
            ``MODEL_CONFIGS[model]`` as ``layer``.
        timeout: Seconds to wait for the HTTP request. Without a timeout a
            stalled connection would block the UI event handler indefinitely.

    Returns:
        The decoded JSON response body, or ``None`` when the request fails,
        the server answers with an HTTP error status, or the body is not
        valid JSON.
    """
    import logging

    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": model.value,
        "text": text,
        "layer": MODEL_CONFIGS[model],
    }
    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # Best-effort lookup: callers treat None as "no results", but record
        # the reason instead of silently discarding it. ValueError covers
        # JSON decoding failures from response.json().
        logging.getLogger(__name__).warning("Neuronpedia request failed: %s", exc)
        return None
def create_dashboard(feature_id: int, model: Model) -> str:
    """Return HTML that embeds the Neuronpedia page for one feature.

    Args:
        feature_id: Index of the feature to show.
        model: Model the feature belongs to; its value and its
            ``MODEL_CONFIGS`` entry (both lowercased) form the URL path.

    Returns:
        A titled ``dashboard-container`` div wrapping a 600px-high iframe that
        points at ``https://www.neuronpedia.org/<model>/<source>/<feature_id>``
        with the embed query flags enabled.
    """
    feature_page = "https://www.neuronpedia.org/{}/{}/{}".format(
        model.value.lower(),
        MODEL_CONFIGS[model].lower(),
        feature_id,
    )
    embed_flags = "embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    return f"""
<div class="dashboard-container p-4">
<h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
<iframe
src="{feature_page}?{embed_flags}"
width="100%"
height="600"
frameborder="0"
class="rounded-lg"
></iframe>
</div>
"""
def handle_feature_click(feature_id: int, model: str):
    """Build the dashboard HTML for a clicked feature button.

    Args:
        feature_id: Feature index attached to the clicked button.
        model: UI model label; "Gemini" maps to ``Model.GEMMA`` and anything
            else to ``Model.GPT2``.

    Returns:
        The HTML produced by ``create_dashboard``.
    """
    if model == "Gemini":
        chosen = Model.GEMMA
    else:
        chosen = Model.GPT2
    return create_dashboard(feature_id, chosen)
def analyze_text(text: str, selected_model: str):
    """Fetch the top features per token of *text* and a dashboard for the first one.

    Args:
        text: Text to analyze.
        selected_model: UI model label; "Gemini" selects ``Model.GEMMA`` and
            any other value selects ``Model.GPT2``.

    Returns:
        A ``(features, dashboard_html)`` pair.

        ``features`` is a list of ``{"token": str, "features": [...]}`` dicts.
        Each inner entry is ``{"token", "id", "activation"}``, with at most
        three entries per token, and ``<bos>`` tokens are skipped.

        ``dashboard_html`` embeds the first feature found. It is ``""`` when
        the input is blank, the request failed, or no feature was returned.
    """
    # Blank or whitespace-only input: nothing to analyze, skip the API call.
    if not text or not text.strip():
        return [], ""
    model = Model.GEMMA if selected_model == "Gemini" else Model.GPT2
    features_data = get_features(text, model)
    if not features_data:
        return [], ""

    features = []
    first_feature_id = None
    # Tolerate a payload without "results"/"top_features" instead of raising KeyError.
    for result in features_data.get("results", []):
        token = result["token"]
        # NOTE(review): only "<bos>" is skipped; confirm whether GPT-2 responses
        # start with a different special token that should be skipped too.
        if token == "<bos>":
            continue
        token_features = []
        for feature in result.get("top_features", [])[:3]:
            feature_id = feature["feature_index"]
            if first_feature_id is None:
                first_feature_id = feature_id
            token_features.append({
                "token": token,
                "id": feature_id,
                "activation": feature["activation_value"],
            })
        features.append({"token": token, "features": token_features})

    # Compare against None explicitly: feature index 0 is valid but falsy.
    if first_feature_id is None:
        return features, ""
    return features, create_dashboard(first_feature_id, model)
# Custom stylesheet passed to gr.Blocks(css=...). It styles the dashboard
# container, the per-token headers and feature buttons, and the model radio
# rendered with elem_id="model-buttons" (pill-style labels, highlighted
# selection, and an icon drawn with ::before on the 1st/2nd label).
# NOTE(review): the icon URLs 'img/gemini-icon.png' and 'img/openai-icon.png'
# are relative to the page URL; confirm these files are actually served
# (Gradio only exposes local files through its file route / allowed_paths).
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body { font-family: 'Open Sans', sans-serif !important; }
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
.token-header {
font-size: 1.25rem;
font-weight: 600;
margin-top: 1rem;
margin-bottom: 0.5rem;
}
.feature-button {
display: inline-block;
margin: 0.25rem;
padding: 0.5rem 1rem;
background-color: #f3f4f6;
border: 1px solid #e5e7eb;
border-radius: 0.375rem;
font-size: 0.875rem;
}
.feature-button:hover {
background-color: #e5e7eb;
}
.model-selector {
display: flex;
gap: 8px;
margin-bottom: 1rem;
}
#model-buttons .gr-form {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
#model-buttons .gr-radio-row {
gap: 8px !important;
}
#model-buttons label {
display: flex !important;
align-items: center !important;
gap: 4px !important;
padding: 4px 12px !important;
border: 1px solid #e5e7eb !important;
border-radius: 6px !important;
font-size: 14px !important;
cursor: pointer !important;
transition: all 0.2s !important;
}
#model-buttons label:hover {
background-color: #f3f4f6 !important;
}
#model-buttons label.selected {
background-color: #4c4ce3 !important;
color: white !important;
border-color: #4c4ce3 !important;
}
#model-buttons label:before {
content: "" !important;
width: 20px !important;
height: 20px !important;
background-size: contain !important;
background-repeat: no-repeat !important;
background-position: center !important;
}
#model-buttons label:nth-child(1):before {
background-image: url('img/gemini-icon.png') !important;
}
#model-buttons label:nth-child(2):before {
background-image: url('img/openai-icon.png') !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
    gr.Markdown(
        "*Analyze text using interpretable neural features*",
        elem_classes="text-gray-600 mb-6",
    )

    # UI label ("Gemini"/"OpenAI") of the model currently selected in the radio.
    current_model = gr.State("Gemini")
    # Per-token feature groups returned by analyze_text.
    features_state = gr.State([])
    # Label of the model that produced features_state. It is kept separate
    # from current_model so that switching the radio does not re-bind
    # already-fetched feature ids to a different model's dashboard.
    analyzed_model = gr.State("Gemini")

    with gr.Row(elem_classes="model-selector"):
        with gr.Column(scale=1):
            with gr.Row():
                model_choice = gr.Radio(
                    choices=["Gemini", "OpenAI"],
                    value="Gemini",
                    label="",
                    elem_classes="model-selector",
                    elem_id="model-buttons",
                    container=False,
                    interactive=True,
                )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to analyze...",
                label="Input Text",
            )
            analyze_btn = gr.Button("Analyze Features", variant="primary")
            gr.Examples(
                examples=["WordLift", "Think Different", "Just Do It"],
                inputs=input_text,
            )

        with gr.Column(scale=2):
            # Previously this function was never decorated or called, so the
            # feature buttons never appeared. gr.render re-runs it whenever
            # its input states change.
            @gr.render(inputs=[features_state, analyzed_model])
            def render_features(features, model):
                """Render one header and a row of feature buttons per token."""
                if not features:
                    return
                for token_group in features:
                    gr.Markdown(f"### {token_group['token']}")
                    with gr.Row():
                        for feature in token_group["features"]:
                            btn = gr.Button(
                                f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
                                elem_classes=["feature-button"],
                            )
                            # Bind the id as a default argument so each button
                            # keeps its own feature (avoids late binding).
                            btn.click(
                                fn=lambda fid=feature["id"]: handle_feature_click(fid, model),
                                outputs=dashboard,
                            )

            dashboard = gr.HTML()

    def update_and_analyze(text, model):
        """Event handler: analyze *text* with the model labelled *model*."""
        return analyze_text(text, model)

    def _analyze_and_record_model(text, model):
        """Run the analysis and also return the model label that produced it."""
        features, dashboard_html = update_and_analyze(text, model)
        return features, dashboard_html, model

    # Mirror the radio selection into the State the handlers read.
    model_choice.change(
        fn=lambda x: x,
        inputs=[model_choice],
        outputs=[current_model],
    )
    analyze_btn.click(
        fn=_analyze_and_record_model,
        inputs=[input_text, current_model],
        outputs=[features_state, dashboard, analyzed_model],
    )
    input_text.submit(
        fn=_analyze_and_record_model,
        inputs=[input_text, current_model],
        outputs=[features_state, dashboard, analyzed_model],
    )
if __name__ == "__main__":
    # Launch with share=False: no public Gradio share link is created.
    launch_options = {"share": False}
    demo.launch(**launch_options)