Spaces:
Sleeping
Sleeping
File size: 9,930 Bytes
f85532f deaf693 f85532f deaf693 f85532f deaf693 f85532f deaf693 f85532f deaf693 f85532f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging
# Route log records at INFO and above to the root handler, then obtain the
# module-scoped logger used throughout this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
@dataclass
class MarketingFeature:
    """Record describing one Gemma Scope SAE latent relevant to marketing text.

    A plain data holder: the latent's identity (index and layer), a display
    label and category, and explanatory text shown alongside results.
    """

    feature_id: int  # index of the latent within its SAE
    name: str  # human-readable label shown in the UI
    category: str  # grouping key, e.g. "technical" or "seo"
    description: str  # short description of what the latent detects
    interpretation_guide: str  # advice on how to read the activation level
    layer: int  # model layer the latent's SAE is attached to
    threshold: float = 0.1  # per-feature activation threshold; defaults to 0.1
# Marketing-relevant latents taken from the Gemma Scope SAEs.
# NOTE(review): the labels/descriptions below are not verified in this file;
# confirm what each feature_id actually encodes before relying on the scores.
MARKETING_FEATURES = [
    MarketingFeature(
        layer=20,
        feature_id=35,
        category="technical",
        name="Technical Term Detector",
        description="Detects technical and specialized terminology",
        interpretation_guide="High activation indicates strong technical focus",
    ),
    MarketingFeature(
        layer=20,
        feature_id=6680,
        category="technical",
        name="Compound Technical Terms",
        description="Identifies complex technical concepts",
        interpretation_guide="Consider simplifying language if activation is too high",
    ),
    MarketingFeature(
        layer=20,
        feature_id=2,
        category="seo",
        name="SEO Keyword Detector",
        description="Identifies potential SEO keywords",
        interpretation_guide="High activation suggests strong SEO potential",
    ),
    # Append further features here as they are identified.
]
class MarketingAnalyzer:
    """Score marketing copy using Gemma 2 activations and Gemma Scope SAEs.

    On construction the analyzer loads a Gemma 2 causal LM and tokenizer, then
    the residual-stream sparse autoencoder (SAE) parameters for every layer
    used by ``MARKETING_FEATURES``. ``analyze_content`` reports how strongly
    each configured SAE latent fires on a piece of text.
    """

    def __init__(self, model_size: str = "2b"):
        """Load the model, tokenizer and SAEs.

        Args:
            model_size: Size suffix shared by the base model
                (``google/gemma-2-{model_size}``) and the SAE repository
                (``google/gemma-scope-{model_size}-pt-res``).

        Raises:
            Exception: Re-raised if the model or tokenizer fails to load.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # _load_saes builds the SAE repo id from this attribute, so it must be
        # stored before loading (it was previously never set, which made every
        # SAE download fail with AttributeError).
        self.model_size = model_size
        self._initialize_model(model_size)
        self._load_saes()

    def _initialize_model(self, model_size: str):
        """Initialize the Gemma 2 model and tokenizer.

        Gemma Scope SAEs were trained on Gemma 2 residual streams, so the base
        model must be a ``google/gemma-2-*`` checkpoint; the first-generation
        ``google/gemma-2b`` has a different hidden size and would not match
        the SAE encoder matrix.

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        import os

        model_name = f"google/gemma-2-{model_size}"
        # The token authenticates the Hub download; a missing token only warns
        # here, and loading may then fail for gated checkpoints.
        hf_token = os.environ.get('HF_TOKEN')
        if not hf_token:
            logger.warning("HF_TOKEN not found in environment variables")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                device_map='auto',  # Automatically handle device placement
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=hf_token,
            )
            self.model.eval()  # Inference only: disable dropout etc.
            logger.info(f"Initialized model: {model_name}")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_saes(self):
        """Load the SAE parameters needed by ``MARKETING_FEATURES``.

        Each feature is a single latent of a layer-wide SAE, so parameters are
        downloaded and moved to the device once per layer and shared by all
        features on that layer (previously one full copy was kept per feature).
        Populates ``self.saes`` as ``{feature_id: {'params': ..., 'feature': ...}}``.
        Features whose SAE cannot be loaded are logged and skipped.
        """
        self.saes = {}
        params_by_layer = {}
        for feature in MARKETING_FEATURES:
            try:
                params = params_by_layer.get(feature.layer)
                if params is None:
                    path = hf_hub_download(
                        repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
                        filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
                    )
                    # NpzFile holds an open file handle; close it once copied.
                    with np.load(path) as npz:
                        # Use self.device (not .cuda()) so CPU-only hosts work.
                        params = {
                            k: torch.from_numpy(npz[k]).to(self.device)
                            for k in npz.files
                        }
                    params_by_layer[feature.layer] = params
                self.saes[feature.feature_id] = {
                    'params': params,
                    'feature': feature,
                }
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.error(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")
                continue

    def analyze_content(self, text: str) -> Dict:
        """Analyze marketing content using the loaded SAEs.

        Args:
            text: Content to analyze.

        Returns:
            Dict with keys ``'text'``, ``'features'`` (feature_id -> result),
            ``'categories'`` (category -> list of feature results) and
            ``'recommendations'`` (list of strings).

        Raises:
            Exception: Any tokenization/model/SAE error is logged and re-raised.
        """
        results = {
            'text': text,
            'features': {},
            'categories': {},
            'recommendations': []
        }
        try:
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
                for feature_id, sae_data in self.saes.items():
                    feature = sae_data['feature']
                    # hidden_states[0] is the embedding output and
                    # hidden_states[i + 1] is the output of block i, which is
                    # where the Gemma Scope residual SAE "layer_i" reads.
                    layer_output = outputs.hidden_states[feature.layer + 1]
                    all_latents = self._apply_sae(
                        layer_output,
                        sae_data['params'],
                        feature.threshold
                    )
                    # The SAE yields every latent; keep only this feature's.
                    activations = all_latents[..., feature.feature_id]
                    feature_result = {
                        'name': feature.name,
                        'category': feature.category,
                        'activation_score': float(activations.mean()),
                        'max_activation': float(activations.max()),
                        'interpretation': self._interpret_activation(
                            activations,
                            feature
                        )
                    }
                    results['features'][feature_id] = feature_result
                    results['categories'].setdefault(feature.category, []).append(feature_result)
            results['recommendations'] = self._generate_recommendations(results)
        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise
        return results

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
        threshold: float
    ) -> torch.Tensor:
        """Encode activations with a JumpReLU SAE and return all latent values.

        Args:
            activations: Residual-stream activations whose last dimension
                matches the rows of ``sae_params['W_enc']``.
            sae_params: Must contain ``'W_enc'``, ``'b_enc'`` and the
                per-latent ``'threshold'`` vector.
            threshold: Accepted for interface compatibility but unused; the
                JumpReLU cut-offs come from ``sae_params['threshold']``.

        Returns:
            Tensor shaped like ``activations`` with the last dimension replaced
            by the number of SAE latents.
        """
        w_enc = sae_params['W_enc']
        # Model activations may live on another device or in another dtype
        # than the float SAE weights; align them before the matmul.
        activations = activations.to(device=w_enc.device, dtype=w_enc.dtype)
        pre_acts = activations @ w_enc + sae_params['b_enc']
        mask = pre_acts > sae_params['threshold']
        return mask * torch.nn.functional.relu(pre_acts)

    def _interpret_activation(
        self,
        activations: torch.Tensor,
        feature: "MarketingFeature"
    ) -> str:
        """Turn the mean activation of ``feature`` into a short verbal label."""
        mean_activation = float(activations.mean())
        if mean_activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif mean_activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Generate content recommendations from per-feature scores.

        Returns an empty list when no technical feature was analyzed (e.g. all
        SAEs failed to load) instead of averaging an empty list into NaN.
        """
        recommendations = []
        tech_scores = [
            f['activation_score'] for f in results['features'].values()
            if f['category'] == 'technical'
        ]
        if not tech_scores:
            return recommendations
        tech_score = float(np.mean(tech_scores))
        if tech_score > 0.8:
            recommendations.append(
                "Consider simplifying technical language for broader audience"
            )
        elif tech_score < 0.3:
            recommendations.append(
                "Could benefit from more specific technical details"
            )
        # Add more recommendation logic as needed
        return recommendations
def create_gradio_interface():
    """Create the Gradio interface for marketing analysis.

    Returns:
        A ``gr.Interface`` wired to ``MarketingAnalyzer.analyze_content``. If the
        analyzer cannot be constructed, a placeholder interface that only
        reports the initialization failure is returned instead.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        # Graceful fallback: keep the app up and explain the failure in the UI.
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize. Please check if HF_TOKEN is properly set."
        )

    def analyze(text):
        """Analyze *text* and format the results as a plain-text report."""
        if not text or not text.strip():
            return "Please enter some marketing content to analyze."
        try:
            results = analyzer.analyze_content(text)
        except Exception:
            # analyze_content has already logged the error; show a readable
            # message instead of Gradio's generic error state.
            return "Error: Analysis failed. Please try again."

        lines = ["Content Analysis Results", "", "Category Scores:"]
        for category, features in results['categories'].items():
            avg_score = np.mean([f['activation_score'] for f in features])
            lines.append(f"{category.title()}: {avg_score:.2f}")

        lines += ["", "Feature Details:"]
        for feature in results['features'].values():
            lines += [
                "",
                f"{feature['name']}:",
                f"Score: {feature['activation_score']:.2f}",
                f"Interpretation: {feature['interpretation']}",
            ]

        lines += ["", "Recommendations:"]
        lines += [f"- {rec}" for rec in results['recommendations']]
        # Trailing newline keeps the report identical to the line-by-line format.
        return "\n".join(lines) + "\n"

    iface = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter your marketing content here..."
        ),
        outputs=gr.Textbox(),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=[
            ["WordLift is an AI-powered SEO tool"],
            ["Our advanced machine learning algorithms optimize your content"],
            ["Simple and effective website optimization"]
        ]
    )
    return iface
if __name__ == "__main__":
    # Script entry point: build the UI (or its error fallback) and serve it.
    demo = create_gradio_interface()
    demo.launch()