brand-llms / app.py
cyberandy's picture
Update app.py
943551c verified
import logging
from enum import Enum

import gradio as gr
import requests
class Model(Enum):
    """Models this app can query on Neuronpedia.

    Each member's value is the Neuronpedia model id: it is sent as ``modelId``
    in search requests and used as the first path segment of dashboard URLs.
    """

    GEMMA = "gemma-2-2b"  # selected by the "Gemini" option in the UI
    GPT2 = "gpt2-small"  # fallback for any other UI option ("OpenAI")
# Neuronpedia source/layer id per model: sent as the "layer" field of search
# requests and (lower-cased) used as the second path segment of dashboard URLs.
MODEL_CONFIGS = dict(
    [
        (Model.GEMMA, "20-gemmascope-res-16k"),
        (Model.GPT2, "9-res-jb"),
    ]
)
def get_features(text: str, model: Model, timeout: float = 30.0):
    """Query Neuronpedia's search-with-topk endpoint for per-token top features.

    Args:
        text: Text to analyze.
        model: Model to query; ``model.value`` is sent as ``modelId`` and
            ``MODEL_CONFIGS[model]`` as ``layer``.
        timeout: Seconds to wait for the HTTP request before giving up. Without
            a timeout a stalled server would block the UI handler forever.

    Returns:
        The decoded JSON response, or ``None`` if the request failed, returned
        an HTTP error status, or the body was not valid JSON. Failures are
        logged rather than raised so the UI can simply show no results.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": model.value,
        "text": text,
        "layer": MODEL_CONFIGS[model],
    }
    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    # RequestException covers connection errors, timeouts and HTTP error
    # statuses; ValueError covers an undecodable JSON body.
    except (requests.RequestException, ValueError) as err:
        logging.getLogger(__name__).warning(
            "Neuronpedia request failed for model %s: %s", model.value, err
        )
        return None
def create_dashboard(feature_id: int, model: Model) -> str:
    """Return an HTML snippet embedding the Neuronpedia page for one feature.

    The iframe points at ``/<model id>/<layer id>/<feature_id>`` on
    neuronpedia.org (both ids lower-cased) with the embed query flags set.
    """
    embed_src = (
        f"https://www.neuronpedia.org/{model.value.lower()}/"
        f"{MODEL_CONFIGS[model].lower()}/{feature_id}"
        "?embed=true&embedexplanation=true&embedplots=true"
        "&embedtest=true&height=300"
    )
    # Leading and trailing "" entries reproduce the surrounding newlines.
    html_lines = [
        "",
        '<div class="dashboard-container p-4">',
        f'<h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>',
        "<iframe",
        f'src="{embed_src}"',
        'width="100%"',
        'height="600"',
        'frameborder="0"',
        'class="rounded-lg"',
        "></iframe>",
        "</div>",
        "",
    ]
    return "\n".join(html_lines)
def handle_feature_click(feature_id: int, model: str):
    """Build the dashboard HTML for a clicked feature button.

    ``model`` is the UI label: "Gemini" selects ``Model.GEMMA``; any other
    value falls back to ``Model.GPT2``.
    """
    if model == "Gemini":
        chosen = Model.GEMMA
    else:
        chosen = Model.GPT2
    return create_dashboard(feature_id, chosen)
def analyze_text(text: str, selected_model: str):
    """Fetch the top features for each token of ``text`` and build a dashboard.

    Args:
        text: Text to analyze. Empty or whitespace-only input is skipped
            without calling the API.
        selected_model: UI label of the model; "Gemini" selects
            ``Model.GEMMA``, anything else ``Model.GPT2``.

    Returns:
        A ``(features, dashboard_html)`` pair. ``features`` is a list of
        ``{"token": ..., "features": [{"token", "id", "activation"}, ...]}``
        dicts with at most three features per token (``<bos>`` is skipped).
        ``dashboard_html`` embeds the dashboard for the first feature found,
        or is ``""`` when there is nothing to show.
    """
    model = Model.GEMMA if selected_model == "Gemini" else Model.GPT2
    # Whitespace-only input has nothing to analyze; avoid the API round trip.
    if not text or not text.strip():
        return [], ""

    features_data = get_features(text, model)
    if not features_data:
        return [], ""

    features = []
    first_feature_id = None
    # A payload without "results" (e.g. an API error body) yields no features
    # instead of raising KeyError.
    for result in features_data.get("results", []):
        token = result["token"]
        if token == "<bos>":
            continue
        token_features = []
        for feature in result["top_features"][:3]:
            feature_id = feature["feature_index"]
            if first_feature_id is None:
                first_feature_id = feature_id
            token_features.append({
                "token": token,
                "id": feature_id,
                "activation": feature["activation_value"],
            })
        features.append({"token": token, "features": token_features})

    # Compare against None explicitly: feature index 0 is valid but falsy.
    if first_feature_id is None:
        return features, ""
    return features, create_dashboard(first_feature_id, model)
# Custom stylesheet passed to gr.Blocks(css=...): font, dashboard/feature-button
# styling, and the model radio (#model-buttons) rendered as icon pill buttons.
# NOTE(review): the icon background-image URLs are relative ('img/...'); confirm
# they actually resolve in the served page -- depending on the Gradio version,
# local files usually have to be exposed explicitly (e.g. allowed_paths or
# gr.set_static_paths) and served under Gradio's file route.
# NOTE(review): selectors such as .gr-form / .gr-radio-row / label.selected
# target Gradio's internal DOM, which changes between versions -- verify
# against the installed Gradio.
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body { font-family: 'Open Sans', sans-serif !important; }
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
.token-header {
font-size: 1.25rem;
font-weight: 600;
margin-top: 1rem;
margin-bottom: 0.5rem;
}
.feature-button {
display: inline-block;
margin: 0.25rem;
padding: 0.5rem 1rem;
background-color: #f3f4f6;
border: 1px solid #e5e7eb;
border-radius: 0.375rem;
font-size: 0.875rem;
}
.feature-button:hover {
background-color: #e5e7eb;
}
.model-selector {
display: flex;
gap: 8px;
margin-bottom: 1rem;
}
#model-buttons .gr-form {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
#model-buttons .gr-radio-row {
gap: 8px !important;
}
#model-buttons label {
display: flex !important;
align-items: center !important;
gap: 4px !important;
padding: 4px 12px !important;
border: 1px solid #e5e7eb !important;
border-radius: 6px !important;
font-size: 14px !important;
cursor: pointer !important;
transition: all 0.2s !important;
}
#model-buttons label:hover {
background-color: #f3f4f6 !important;
}
#model-buttons label.selected {
background-color: #4c4ce3 !important;
color: white !important;
border-color: #4c4ce3 !important;
}
#model-buttons label:before {
content: "" !important;
width: 20px !important;
height: 20px !important;
background-size: contain !important;
background-repeat: no-repeat !important;
background-position: center !important;
}
#model-buttons label:nth-child(1):before {
background-image: url('img/gemini-icon.png') !important;
}
#model-buttons label:nth-child(2):before {
background-image: url('img/openai-icon.png') !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
    gr.Markdown("*Analyze text using interpretable neural features*", elem_classes="text-gray-600 mb-6")

    # Model selected in the radio; used for the *next* analysis.
    current_model = gr.State("Gemini")
    # Per-token feature groups returned by the last analysis.
    features_state = gr.State([])
    # Model that produced ``features_state``. Tracked separately so switching
    # the radio after an analysis does not make the existing feature buttons
    # open dashboards for the other model (feature ids are per-model).
    analyzed_model = gr.State("Gemini")

    with gr.Row(elem_classes="model-selector"):
        with gr.Column(scale=1):
            with gr.Row():
                model_choice = gr.Radio(
                    choices=["Gemini", "OpenAI"],
                    value="Gemini",
                    label="",
                    elem_classes="model-selector",
                    elem_id="model-buttons",
                    container=False,
                    interactive=True,
                )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to analyze...",
                label="Input Text",
            )
            analyze_btn = gr.Button("Analyze Features", variant="primary")
            gr.Examples(
                examples=["WordLift", "Think Different", "Just Do It"],
                inputs=input_text,
            )
        with gr.Column(scale=2):
            @gr.render(inputs=[features_state, analyzed_model])
            def render_features(features, model):
                """Draw a header and a row of feature buttons for each token."""
                if not features:
                    return
                for token_group in features:
                    gr.Markdown(f"### {token_group['token']}")
                    with gr.Row():
                        for feature in token_group["features"]:
                            btn = gr.Button(
                                f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
                                elem_classes=["feature-button"],
                            )
                            # Bind the id as a default argument; a plain closure
                            # would see only the loop's last feature id.
                            btn.click(
                                fn=lambda fid=feature["id"]: handle_feature_click(fid, model),
                                outputs=dashboard,
                            )

            dashboard = gr.HTML()

    def update_and_analyze(text, model):
        """Run the analysis and also record which model produced it."""
        features, dashboard_html = analyze_text(text, model)
        return features, dashboard_html, model

    model_choice.change(
        fn=lambda x: x,
        inputs=[model_choice],
        outputs=[current_model],
    )
    analyze_btn.click(
        fn=update_and_analyze,
        inputs=[input_text, current_model],
        outputs=[features_state, dashboard, analyzed_model],
    )
    input_text.submit(
        fn=update_and_analyze,
        inputs=[input_text, current_model],
        outputs=[features_state, dashboard, analyzed_model],
    )

if __name__ == "__main__":
    demo.launch(share=False)