# brand-llms / app.py — Hugging Face Space by cyberandy
# Last commit: "Update app.py" (deaf693, verified); file size 9.93 kB.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging
# Configure root logging once at import time (INFO and above), then obtain a
# module-scoped logger used throughout this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
@dataclass()
class MarketingFeature:
    """Describe one Gemma Scope SAE latent that the analyzer reports on.

    Instances are plain mutable records. Field order defines the positional
    constructor order, so any field added after ``threshold`` needs a default.
    """

    # Identifier of the latent within its SAE.
    feature_id: int
    # Human-readable label (shown in the analysis output).
    name: str
    # Grouping key, e.g. "technical" or "seo"; results are aggregated by it.
    category: str
    # Short description of what the latent is believed to detect.
    description: str
    # Free-text guidance for reading the activation (not read elsewhere in this file).
    interpretation_guide: str
    # Layer whose residual-stream SAE contains this latent.
    layer: int
    # Per-feature threshold, default 0.1.
    # NOTE(review): the analyzer forwards it to _apply_sae, which gates on the
    # SAE's own stored threshold instead.
    threshold: float = 0.1
# Marketing-relevant Gemma Scope latents, declared as a compact table.
# Row layout: (feature_id, name, category, description, interpretation_guide, layer).
# The threshold is left at the MarketingFeature default. Append rows as new
# relevant features are identified.
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=feature_id,
        name=name,
        category=category,
        description=description,
        interpretation_guide=guide,
        layer=layer,
    )
    for feature_id, name, category, description, guide, layer in (
        (
            35,
            "Technical Term Detector",
            "technical",
            "Detects technical and specialized terminology",
            "High activation indicates strong technical focus",
            20,
        ),
        (
            6680,
            "Compound Technical Terms",
            "technical",
            "Identifies complex technical concepts",
            "Consider simplifying language if activation is too high",
            20,
        ),
        (
            2,
            "SEO Keyword Detector",
            "seo",
            "Identifies potential SEO keywords",
            "High activation suggests strong SEO potential",
            20,
        ),
    )
]
class MarketingAnalyzer:
    """Analyze marketing content using Gemma 2 and Gemma Scope sparse autoencoders.

    The text is run through the language model. The residual stream at each
    configured layer is encoded with the matching Gemma Scope SAE, and the
    activation of every latent listed in ``MARKETING_FEATURES`` is reported.
    """

    def __init__(self, model_size: str = "2b"):
        """Load the model, tokenizer and SAEs.

        Args:
            model_size: Gemma 2 size suffix (e.g. ``"2b"``). Used for both the
                model repo and the matching ``gemma-scope-<size>-pt-res`` repo.
        """
        # Must be set before _load_saes(), which builds the SAE repo id from it.
        self.model_size = model_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._initialize_model(model_size)
        self._load_saes()

    def _initialize_model(self, model_size: str):
        """Initialize the Gemma 2 model and tokenizer.

        An optional Hugging Face access token is read from ``HF_TOKEN``.

        Raises:
            Exception: any loading error is logged (with traceback) and re-raised.
        """
        import os

        # Gemma Scope SAEs were trained on Gemma 2 (not Gemma 1) models. The LM
        # must be gemma-2-<size> so its residual width matches the SAE's W_enc.
        model_name = f"google/gemma-2-{model_size}"
        hf_token = os.environ.get('HF_TOKEN')
        if not hf_token:
            logger.warning("HF_TOKEN not found in environment variables")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                device_map='auto'  # Automatically handle device placement
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=hf_token
            )
        except Exception as e:
            logger.exception(f"Error initializing model: {str(e)}")
            raise
        self.model.eval()  # Inference only: disable dropout etc.
        logger.info(f"Initialized model: {model_name}")

    def _load_saes(self):
        """Load the SAE needed by each entry of ``MARKETING_FEATURES``.

        Populates ``self.saes`` as ``{feature_id: {'params': ..., 'feature': ...}}``.
        SAE weights are loaded once per layer and shared by every feature on
        that layer. A feature whose SAE fails to load is logged and skipped, so
        the analyzer still works with the remaining features.
        """
        self.saes = {}
        params_by_layer: Dict[int, Dict[str, torch.Tensor]] = {}
        for feature in MARKETING_FEATURES:
            try:
                params = params_by_layer.get(feature.layer)
                if params is None:
                    params = self._load_sae_params(feature.layer)
                    params_by_layer[feature.layer] = params
                self.saes[feature.feature_id] = {
                    'params': params,
                    'feature': feature
                }
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.exception(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")
                continue

    def _load_sae_params(self, layer: int) -> Dict[str, torch.Tensor]:
        """Download one layer's residual-stream SAE and return its tensors on ``self.device``."""
        # NOTE(review): the "average_l0_71" variant exists for layer 20. Other
        # layers publish different L0 directories, so verify before adding
        # features on another layer.
        path = hf_hub_download(
            repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
            filename=f"layer_{layer}/width_16k/average_l0_71/params.npz"
        )
        # Context manager closes the .npz archive; .to(self.device) (not .cuda())
        # keeps CPU-only hosts working.
        with np.load(path) as npz:
            return {k: torch.from_numpy(v).to(self.device) for k, v in npz.items()}

    def analyze_content(self, text: str) -> Dict:
        """Score ``text`` against every loaded marketing feature.

        Args:
            text: marketing copy to analyze.

        Returns:
            Dict with keys ``'text'``, ``'features'`` (feature_id -> dict with
            name, category, activation_score, max_activation, interpretation),
            ``'categories'`` (category -> list of those feature dicts) and
            ``'recommendations'`` (list of strings).

        Raises:
            Exception: any tokenizer/model/SAE error is logged and re-raised.
        """
        results = {
            'text': text,
            'features': {},
            'categories': {},
            'recommendations': []
        }
        try:
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
            start = self._first_content_position(inputs['input_ids'])
            for feature_id, sae_data in self.saes.items():
                feature = sae_data['feature']
                # hidden_states[0] is the embedding output, so the residual stream
                # after decoder block `layer` (what a "layer_N" SAE encodes) is at
                # index layer + 1. For the very last layer HF returns the
                # post-norm output instead.
                layer_output = outputs.hidden_states[feature.layer + 1][:, start:]
                sae_acts = self._apply_sae(
                    layer_output,
                    sae_data['params'],
                    feature.threshold
                )
                # The SAE yields one value per latent. Keep only this feature's
                # latent, so the score is per-feature rather than an average over
                # the whole dictionary.
                activations = sae_acts[..., feature.feature_id]
                feature_result = {
                    'name': feature.name,
                    'category': feature.category,
                    'activation_score': float(activations.mean()),
                    'max_activation': float(activations.max()),
                    'interpretation': self._interpret_activation(
                        activations,
                        feature
                    )
                }
                results['features'][feature_id] = feature_result
                results['categories'].setdefault(feature.category, []).append(feature_result)
            results['recommendations'] = self._generate_recommendations(results)
        except Exception as e:
            logger.exception(f"Error analyzing content: {str(e)}")
            raise
        return results

    def _first_content_position(self, input_ids: torch.Tensor) -> int:
        """Return 1 to skip a leading BOS token, otherwise 0.

        The BOS position carries outlier residual activations, and the Gemma
        Scope tutorial excludes it. It is only skipped when at least one other
        token remains, so scores are never taken over an empty sequence.
        """
        bos_id = getattr(self.tokenizer, 'bos_token_id', None)
        if bos_id is None or input_ids.shape[-1] < 2:
            return 0
        return 1 if int(input_ids[0, 0]) == bos_id else 0

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
        threshold: float
    ) -> torch.Tensor:
        """Encode hidden states with a JumpReLU SAE.

        Args:
            activations: hidden states whose last dim matches the rows of
                ``W_enc``; presumably (batch, seq, d_model).
            sae_params: SAE tensors. ``W_enc``, ``b_enc`` and ``threshold`` are used.
            threshold: kept for interface compatibility. Gating uses the
                per-latent JumpReLU threshold in ``sae_params['threshold']``.

        Returns:
            Latent activations: the last dim becomes the SAE width. The result
            is on the SAE's device and dtype.
        """
        w_enc = sae_params['W_enc']
        # Hidden states may differ from the SAE tensors in device (device_map='auto')
        # or dtype (reduced-precision models), so align them before the matmul.
        x = activations.to(device=w_enc.device, dtype=w_enc.dtype)
        pre_acts = x @ w_enc + sae_params['b_enc']
        mask = pre_acts > sae_params['threshold']
        return mask * torch.nn.functional.relu(pre_acts)

    def _interpret_activation(
        self,
        activations: torch.Tensor,
        feature: 'MarketingFeature'
    ) -> str:
        """Map the mean activation to a short verbal strength label for ``feature``."""
        mean_activation = float(activations.mean())
        if mean_activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif mean_activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Derive content recommendations from the per-feature scores in ``results``."""
        recommendations = []
        tech_scores = [
            f['activation_score'] for f in results['features'].values()
            if f['category'] == 'technical'
        ]
        # Without technical features, np.mean([]) would be NaN (+ RuntimeWarning).
        if tech_scores:
            tech_score = float(np.mean(tech_scores))
            if tech_score > 0.8:
                recommendations.append(
                    "Consider simplifying technical language for broader audience"
                )
            elif tech_score < 0.3:
                recommendations.append(
                    "Could benefit from more specific technical details"
                )
        # Add more recommendation logic as needed
        return recommendations
def create_gradio_interface():
    """Build the Gradio UI for the marketing content analyzer.

    If the analyzer cannot be initialized (e.g. missing HF_TOKEN or a model
    download failure), an error-only interface is returned instead. The app
    still starts and explains the problem.

    Returns:
        gr.Interface: the interface to launch.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.exception(f"Failed to initialize analyzer: {str(e)}")
        # Provide a more graceful fallback or error message in the interface
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize. Please check if HF_TOKEN is properly set."
        )

    def analyze(text):
        """Run the analyzer on ``text`` and render the results as plain text."""
        # Don't run the model on blank input; tell the user what's expected.
        if not text or not text.strip():
            return "Please enter some marketing content to analyze."
        try:
            results = analyzer.analyze_content(text)
        except Exception:
            # analyze_content logs the error before re-raising. Show a readable
            # message instead of Gradio's generic error.
            return "Error: analysis failed. Please check the server logs for details."

        lines = ["Content Analysis Results", "", "Category Scores:"]
        # Overall category scores: mean activation of the features in each category.
        for category, features in results['categories'].items():
            avg_score = np.mean([f['activation_score'] for f in features])
            lines.append(f"{category.title()}: {avg_score:.2f}")

        lines += ["", "Feature Details:"]
        for feature in results['features'].values():
            lines += [
                "",
                f"{feature['name']}:",
                f"Score: {feature['activation_score']:.2f}",
                f"Interpretation: {feature['interpretation']}",
            ]

        lines += ["", "Recommendations:"]
        lines.extend(f"- {rec}" for rec in results['recommendations'])
        return "\n".join(lines) + "\n"

    iface = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter your marketing content here..."
        ),
        outputs=gr.Textbox(),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=[
            ["WordLift is an AI-powered SEO tool"],
            ["Our advanced machine learning algorithms optimize your content"],
            ["Simple and effective website optimization"]
        ]
    )
    return iface
if __name__ == "__main__":
    # Build the UI (or the error-only fallback) and start the Gradio server.
    create_gradio_interface().launch()