"""Gradio app that scores marketing copy with Gemma Scope SAE features.

The app runs ``google/gemma-2-2b`` on the input text, captures the output of
a decoder layer, encodes it with a Gemma Scope JumpReLU sparse autoencoder
and reports how strongly a few hand-picked SAE features activate.
"""

import html
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class MarketingFeature:
    """Structure to hold marketing-relevant feature information"""

    feature_id: int  # column index of the feature inside the SAE latent space
    name: str
    category: str  # used to group features in the report (e.g. "technical")
    description: str
    interpretation_guide: str
    layer: int  # decoder layer whose output is fed to the SAE
    threshold: float = 0.1  # NOTE(review): not read anywhere in this file


# Define relevant features
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=35,
        name="Technical Term Detector",
        category="technical",
        description="Detects technical and specialized terminology",
        interpretation_guide="High activation indicates strong technical focus",
        layer=20,
    ),
    MarketingFeature(
        feature_id=6680,
        name="Compound Technical Terms",
        category="technical",
        description="Identifies complex technical concepts",
        interpretation_guide="Consider simplifying language if activation is too high",
        layer=20,
    ),
    MarketingFeature(
        feature_id=2,
        name="SEO Keyword Detector",
        category="seo",
        description="Identifies potential SEO keywords",
        interpretation_guide="High activation suggests strong SEO potential",
        layer=20,
    ),
]


class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder gates pre-activations by a threshold.

    Parameters are created as zeros and are expected to be filled in with
    ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        """Allocate parameters for a ``d_model`` -> ``d_sae`` -> ``d_model`` SAE."""
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Map model activations (last dim ``d_model``) to SAE latents."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        # Latents whose pre-activation does not exceed the per-latent
        # threshold are zeroed; the rest go through a plain ReLU.
        mask = pre_acts > self.threshold
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map SAE latents back to the model's activation space."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Return the reconstruction ``decode(encode(acts))``."""
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon


class MarketingAnalyzer:
    """Scores text against ``MARKETING_FEATURES`` using Gemma 2 2B and its SAEs."""

    def __init__(self):
        """Load the language model and the SAEs needed by every feature.

        Raises:
            Exception: whatever model/tokenizer loading raises (re-raised
                after logging by ``_initialize_model``).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Process-wide switch: the app only does inference.
        torch.set_grad_enabled(False)  # Avoid memory issues
        self._initialize_model()
        self._load_saes()

    def _initialize_model(self):
        """Load ``google/gemma-2-2b`` and its tokenizer; log and re-raise on failure."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/gemma-2-2b", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
            self.model.eval()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.exception("Error initializing model: %s", e)
            raise

    def _load_sae_for_layer(self, layer: int) -> JumpReLUSAE:
        """Download the Gemma Scope residual SAE for ``layer`` and build it."""
        path = hf_hub_download(
            repo_id="google/gemma-scope-2b-pt-res",
            filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
            force_download=False,
        )
        # Copy every array out before the archive handle is closed.
        with np.load(path) as params:
            sae_params = {
                k: torch.from_numpy(v).to(self.device) for k, v in params.items()
            }
        d_model, d_sae = sae_params["W_enc"].shape
        sae = JumpReLUSAE(d_model, d_sae).to(self.device)
        sae.load_state_dict(sae_params)
        return sae

    def _load_saes(self):
        """Populate ``self.saes`` as ``feature_id -> {"sae", "feature"}``.

        Features that share a layer share one SAE instance, so the weights
        are downloaded and held in memory once per layer. A feature whose
        SAE cannot be loaded is logged and skipped (best effort).
        """
        self.saes = {}
        sae_by_layer: Dict[int, JumpReLUSAE] = {}
        for feature in MARKETING_FEATURES:
            try:
                sae = sae_by_layer.get(feature.layer)
                if sae is None:
                    sae = self._load_sae_for_layer(feature.layer)
                    sae_by_layer[feature.layer] = sae
                d_sae = sae.W_enc.shape[1]
                if not 0 <= feature.feature_id < d_sae:
                    raise ValueError(
                        f"feature_id {feature.feature_id} is outside the SAE "
                        f"width {d_sae}"
                    )
                self.saes[feature.feature_id] = {"sae": sae, "feature": feature}
                logger.info("Loaded SAE for feature %s", feature.feature_id)
            except Exception as e:
                logger.exception(
                    "Error loading SAE for feature %s: %s", feature.feature_id, e
                )
                continue

    def _gather_activations(self, text: str, layer: int):
        """Run the model on ``text`` and capture the output of one decoder layer.

        Returns:
            ``(activations, inputs)`` where ``activations`` is the hidden
            state emitted by ``self.model.model.layers[layer]`` (``None`` if
            the hook never fired) and ``inputs`` is the tokenizer output.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        captured = []

        def hook(module, hook_inputs, outputs):
            # The layer may hand back a tuple whose first item is the hidden
            # state, or the hidden state itself; accept both so indexing
            # never slices the batch dimension by accident.
            hidden = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            captured.append(hidden)

        handle = self.model.model.layers[layer].register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(**inputs)
        finally:
            # Always detach the hook, even when the forward pass fails,
            # otherwise hooks pile up on the shared model.
            handle.remove()

        target_act = captured[-1] if captured else None
        return target_act, inputs

    def analyze_content(self, text: str) -> Dict:
        """Score ``text`` against every loaded feature.

        Returns:
            A dict with keys ``"text"``, ``"features"`` (``feature_id`` ->
            per-feature result), ``"categories"`` (category -> list of the
            same result dicts) and ``"recommendations"`` (list of strings).

        Raises:
            Exception: any failure during analysis is logged and re-raised.
        """
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }

        try:
            # One forward pass per distinct layer instead of one per feature.
            activations_by_layer: Dict[int, Optional[torch.Tensor]] = {}

            # Get activations for each feature
            for feature_id, sae_data in self.saes.items():
                feature = sae_data["feature"]
                sae = sae_data["sae"]

                # Get layer activations
                if feature.layer not in activations_by_layer:
                    activations_by_layer[feature.layer], _ = self._gather_activations(
                        text, feature.layer
                    )
                activations = activations_by_layer[feature.layer]
                if activations is None:
                    raise RuntimeError(
                        f"No activations captured for layer {feature.layer}"
                    )

                sae_acts = sae.encode(
                    activations.to(device=self.device, dtype=torch.float32)
                )
                # Skip the BOS token and keep only this feature's latent;
                # averaging over every latent would give all features the
                # same score.
                # assumes sae_acts is (batch, seq, d_sae) — TODO confirm
                feature_acts = sae_acts[:, 1:, feature.feature_id]

                # Calculate metrics
                if feature_acts.numel() > 0:
                    mean_activation = float(feature_acts.mean())
                    max_activation = float(feature_acts.max())
                else:
                    mean_activation = 0.0
                    max_activation = 0.0

                # Record results
                feature_result = {
                    "name": feature.name,
                    "category": feature.category,
                    "layer": feature.layer,
                    "activation_score": mean_activation,
                    "max_activation": max_activation,
                    "interpretation": self._interpret_activation(
                        mean_activation, feature
                    ),
                }

                results["features"][feature_id] = feature_result
                results["categories"].setdefault(feature.category, []).append(
                    feature_result
                )

            results["recommendations"] = self._generate_recommendations(results)
        except Exception as e:
            logger.exception("Error analyzing content: %s", e)
            raise

        return results

    def _interpret_activation(
        self, activation: float, feature: MarketingFeature
    ) -> str:
        """Turn a mean activation into a short human-readable verdict."""
        if activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Derive writing advice from the mean score of "technical" features.

        Errors are logged and whatever was collected so far is returned.
        """
        recommendations = []

        try:
            tech_features = [
                f for f in results["features"].values() if f["category"] == "technical"
            ]

            if tech_features:
                tech_score = np.mean([f["activation_score"] for f in tech_features])
                if tech_score > 0.8:
                    recommendations.append(
                        "Consider simplifying technical language for broader audience"
                    )
                elif tech_score < 0.3:
                    recommendations.append(
                        "Could benefit from more specific technical details"
                    )
        except Exception as e:
            logger.exception("Error generating recommendations: %s", e)

        return recommendations


def create_gradio_interface():
    """Build the Gradio UI.

    Returns:
        A ``gr.Blocks`` app wired to a ``MarketingAnalyzer``, or a minimal
        ``gr.Interface`` that only reports the failure when the analyzer
        cannot be initialized.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.exception("Failed to initialize analyzer: %s", e)
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize.",
        )

    def neuronpedia_url(feature_id, layer):
        # Page of one feature of the 16k-wide residual SAE at ``layer``.
        return (
            f"https://neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-16k/{feature_id}"
        )

    def analyze(text):
        """Return ``(markdown_report, dashboard_url, top_feature_id)``.

        ``dashboard_url`` and ``top_feature_id`` are ``None`` when no
        feature could be scored (e.g. every SAE failed to load).
        """
        results = analyzer.analyze_content(text)

        parts = ["# Content Analysis Results\n\n", "## Category Scores\n"]
        for category, features in results["categories"].items():
            if features:
                avg_score = np.mean([f["activation_score"] for f in features])
                parts.append(f"**{category.title()}**: {avg_score:.2f}\n")

        parts.append("\n## Feature Details\n")
        for feature_id, feature in results["features"].items():
            parts.append(f"\n### {feature['name']} (Feature {feature_id})\n")
            parts.append(f"**Score**: {feature['activation_score']:.2f}\n\n")
            parts.append(f"**Interpretation**: {feature['interpretation']}\n\n")
            # Add feature explanation from Neuronpedia reference
            feature_url = neuronpedia_url(feature_id, feature.get("layer", 20))
            parts.append(f"[View feature details on Neuronpedia]({feature_url})\n\n")

        if results["recommendations"]:
            parts.append("\n## Recommendations\n")
            parts.extend(f"- {rec}\n" for rec in results["recommendations"])

        output = "".join(parts)

        if not results["features"]:
            # max() on an empty mapping would raise; report without dashboard.
            return output, None, None

        feature_id, top_feature = max(
            results["features"].items(), key=lambda item: item[1]["activation_score"]
        )

        # Build dashboard URL for the highest activating feature
        dashboard_url = (
            neuronpedia_url(feature_id, top_feature.get("layer", 20))
            + "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
        )

        return output, dashboard_url, feature_id

    with gr.Blocks(
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
            primary_hue="indigo",
            secondary_hue="blue",
            neutral_hue="gray",
        )
    ) as interface:
        gr.Markdown("# Marketing Content Analyzer")
        gr.Markdown(
            "Analyze your marketing content using Gemma Scope's neural features"
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter your marketing content here...",
                    label="Marketing Content",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                gr.Examples(
                    examples=[
                        "WordLift is an AI-powered SEO tool",
                        "Our advanced machine learning algorithms optimize your content",
                        "Simple and effective website optimization",
                    ],
                    inputs=input_text,
                )

            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Analysis Results")
                # gr.Group is used instead of gr.Box so the layout also works
                # on Gradio releases that no longer ship gr.Box.
                with gr.Group():
                    gr.Markdown("## Feature Dashboard")
                    feature_id_text = gr.Text(
                        label="Currently viewing feature", show_label=False
                    )
                    dashboard_frame = gr.HTML(label="Feature Dashboard")

        def update_dashboard(text):
            """Produce the three UI outputs: report, dashboard HTML, caption."""
            output, dashboard_url, feature_id = analyze(text)
            if feature_id is None:
                return output, "", "No feature available to display"
            # Escape the URL so its '&' separators are valid inside the
            # HTML attribute.
            safe_url = html.escape(dashboard_url, quote=True)
            dashboard_html = (
                f'<iframe src="{safe_url}" width="100%" height="300" '
                'frameborder="0"></iframe>'
            )
            return (
                output,
                dashboard_html,
                f"Currently viewing Feature {feature_id} - Most active feature in your content",
            )

        analyze_btn.click(
            fn=update_dashboard,
            inputs=input_text,
            outputs=[output_text, dashboard_frame, feature_id_text],
        )

    return interface


if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()