Spaces:
Sleeping
Sleeping
update
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
|
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
import numpy as np
|
@@ -7,7 +8,6 @@ from dataclasses import dataclass
|
|
7 |
from typing import List, Dict, Optional
|
8 |
import logging
|
9 |
|
10 |
-
# Initialize logging
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
@@ -25,7 +25,7 @@ class MarketingFeature:
|
|
25 |
threshold: float = 0.1
|
26 |
|
27 |
|
28 |
-
# Define
|
29 |
MARKETING_FEATURES = [
|
30 |
MarketingFeature(
|
31 |
feature_id=35,
|
@@ -33,7 +33,7 @@ MARKETING_FEATURES = [
|
|
33 |
category="technical",
|
34 |
description="Detects technical and specialized terminology",
|
35 |
interpretation_guide="High activation indicates strong technical focus",
|
36 |
-
layer=
|
37 |
),
|
38 |
MarketingFeature(
|
39 |
feature_id=6680,
|
@@ -41,7 +41,7 @@ MARKETING_FEATURES = [
|
|
41 |
category="technical",
|
42 |
description="Identifies complex technical concepts",
|
43 |
interpretation_guide="Consider simplifying language if activation is too high",
|
44 |
-
layer=
|
45 |
),
|
46 |
MarketingFeature(
|
47 |
feature_id=2,
|
@@ -49,57 +49,77 @@ MARKETING_FEATURES = [
|
|
49 |
category="seo",
|
50 |
description="Identifies potential SEO keywords",
|
51 |
interpretation_guide="High activation suggests strong SEO potential",
|
52 |
-
layer=
|
53 |
),
|
54 |
]
|
55 |
|
56 |
|
57 |
-
class
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
def __init__(self):
|
61 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
62 |
-
#
|
63 |
-
self.model_size = "2b"
|
64 |
self._initialize_model()
|
65 |
self._load_saes()
|
66 |
|
67 |
def _initialize_model(self):
|
68 |
-
"""Initialize Gemma model and tokenizer"""
|
69 |
try:
|
70 |
-
model_name = f"google/gemma-{self.model_size}"
|
71 |
-
|
72 |
-
# Initialize model and tokenizer with token from environment
|
73 |
self.model = AutoModelForCausalLM.from_pretrained(
|
74 |
-
|
75 |
)
|
76 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
77 |
-
|
78 |
self.model.eval()
|
79 |
-
logger.info(
|
80 |
-
|
81 |
except Exception as e:
|
82 |
logger.error(f"Error initializing model: {str(e)}")
|
83 |
raise
|
84 |
|
85 |
def _load_saes(self):
|
86 |
-
"""Load relevant SAEs from Gemma Scope"""
|
87 |
self.saes = {}
|
88 |
for feature in MARKETING_FEATURES:
|
89 |
try:
|
90 |
-
# Load SAE parameters for each feature
|
91 |
path = hf_hub_download(
|
92 |
-
repo_id=
|
93 |
filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
|
|
|
94 |
)
|
95 |
params = np.load(path)
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
102 |
}
|
|
|
|
|
|
|
103 |
logger.info(f"Loaded SAE for feature {feature.feature_id}")
|
104 |
except Exception as e:
|
105 |
logger.error(
|
@@ -107,8 +127,23 @@ class MarketingAnalyzer:
|
|
107 |
)
|
108 |
continue
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def analyze_content(self, text: str) -> Dict:
|
111 |
-
"""Analyze marketing content using loaded SAEs"""
|
112 |
results = {
|
113 |
"text": text,
|
114 |
"features": {},
|
@@ -117,26 +152,22 @@ class MarketingAnalyzer:
|
|
117 |
}
|
118 |
|
119 |
try:
|
120 |
-
# Get
|
121 |
-
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
|
122 |
-
with torch.no_grad():
|
123 |
-
outputs = self.model(**inputs, output_hidden_states=True)
|
124 |
-
|
125 |
-
# Analyze each feature
|
126 |
for feature_id, sae_data in self.saes.items():
|
127 |
feature = sae_data["feature"]
|
128 |
-
|
129 |
|
130 |
-
#
|
131 |
-
activations = self.
|
132 |
-
|
133 |
-
|
|
|
|
|
134 |
|
135 |
-
#
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
max_activation = float(activations.max())
|
140 |
else:
|
141 |
mean_activation = 0.0
|
142 |
max_activation = 0.0
|
@@ -154,12 +185,10 @@ class MarketingAnalyzer:
|
|
154 |
|
155 |
results["features"][feature_id] = feature_result
|
156 |
|
157 |
-
# Aggregate by category
|
158 |
if feature.category not in results["categories"]:
|
159 |
results["categories"][feature.category] = []
|
160 |
results["categories"][feature.category].append(feature_result)
|
161 |
|
162 |
-
# Generate recommendations
|
163 |
results["recommendations"] = self._generate_recommendations(results)
|
164 |
|
165 |
except Exception as e:
|
@@ -168,22 +197,9 @@ class MarketingAnalyzer:
|
|
168 |
|
169 |
return results
|
170 |
|
171 |
-
def _apply_sae(
|
172 |
-
self,
|
173 |
-
activations: torch.Tensor,
|
174 |
-
sae_params: Dict[str, torch.Tensor],
|
175 |
-
threshold: float,
|
176 |
-
) -> torch.Tensor:
|
177 |
-
"""Apply SAE to get feature activations"""
|
178 |
-
pre_acts = activations @ sae_params["W_enc"] + sae_params["b_enc"]
|
179 |
-
mask = pre_acts > sae_params["threshold"]
|
180 |
-
acts = mask * torch.nn.functional.relu(pre_acts)
|
181 |
-
return acts
|
182 |
-
|
183 |
def _interpret_activation(
|
184 |
self, activation: float, feature: MarketingFeature
|
185 |
) -> str:
|
186 |
-
"""Interpret activation patterns for a feature"""
|
187 |
if activation > 0.8:
|
188 |
return f"Very strong presence of {feature.name.lower()}"
|
189 |
elif activation > 0.5:
|
@@ -192,19 +208,15 @@ class MarketingAnalyzer:
|
|
192 |
return f"Limited presence of {feature.name.lower()}"
|
193 |
|
194 |
def _generate_recommendations(self, results: Dict) -> List[str]:
|
195 |
-
"""Generate content recommendations based on analysis"""
|
196 |
recommendations = []
|
197 |
|
198 |
try:
|
199 |
-
# Get technical features
|
200 |
tech_features = [
|
201 |
f for f in results["features"].values() if f["category"] == "technical"
|
202 |
]
|
203 |
|
204 |
-
# Calculate average technical score if we have features
|
205 |
if tech_features:
|
206 |
tech_score = np.mean([f["activation_score"] for f in tech_features])
|
207 |
-
|
208 |
if tech_score > 0.8:
|
209 |
recommendations.append(
|
210 |
"Consider simplifying technical language for broader audience"
|
@@ -220,7 +232,6 @@ class MarketingAnalyzer:
|
|
220 |
|
221 |
|
222 |
def create_gradio_interface():
|
223 |
-
"""Create Gradio interface for marketing analysis"""
|
224 |
try:
|
225 |
analyzer = MarketingAnalyzer()
|
226 |
except Exception as e:
|
@@ -230,30 +241,26 @@ def create_gradio_interface():
|
|
230 |
inputs=gr.Textbox(),
|
231 |
outputs=gr.Textbox(),
|
232 |
title="Marketing Content Analyzer (Error)",
|
233 |
-
description="Failed to initialize.
|
234 |
)
|
235 |
|
236 |
def analyze(text):
|
237 |
results = analyzer.analyze_content(text)
|
238 |
|
239 |
-
# Format results for display
|
240 |
output = "Content Analysis Results\n\n"
|
241 |
|
242 |
-
# Overall category scores
|
243 |
output += "Category Scores:\n"
|
244 |
for category, features in results["categories"].items():
|
245 |
-
if features:
|
246 |
avg_score = np.mean([f["activation_score"] for f in features])
|
247 |
output += f"{category.title()}: {avg_score:.2f}\n"
|
248 |
|
249 |
-
# Feature details
|
250 |
output += "\nFeature Details:\n"
|
251 |
for feature_id, feature in results["features"].items():
|
252 |
output += f"\n{feature['name']}:\n"
|
253 |
output += f"Score: {feature['activation_score']:.2f}\n"
|
254 |
output += f"Interpretation: {feature['interpretation']}\n"
|
255 |
|
256 |
-
# Recommendations
|
257 |
if results["recommendations"]:
|
258 |
output += "\nRecommendations:\n"
|
259 |
for rec in results["recommendations"]:
|
@@ -261,28 +268,38 @@ def create_gradio_interface():
|
|
261 |
|
262 |
return output
|
263 |
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
)
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
return interface
|
288 |
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
import torch.nn as nn
|
4 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
5 |
from huggingface_hub import hf_hub_download
|
6 |
import numpy as np
|
|
|
8 |
from typing import List, Dict, Optional
|
9 |
import logging
|
10 |
|
|
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
|
|
25 |
threshold: float = 0.1
|
26 |
|
27 |
|
28 |
+
# Define relevant features
|
29 |
MARKETING_FEATURES = [
|
30 |
MarketingFeature(
|
31 |
feature_id=35,
|
|
|
33 |
category="technical",
|
34 |
description="Detects technical and specialized terminology",
|
35 |
interpretation_guide="High activation indicates strong technical focus",
|
36 |
+
layer=20,
|
37 |
),
|
38 |
MarketingFeature(
|
39 |
feature_id=6680,
|
|
|
41 |
category="technical",
|
42 |
description="Identifies complex technical concepts",
|
43 |
interpretation_guide="Consider simplifying language if activation is too high",
|
44 |
+
layer=20,
|
45 |
),
|
46 |
MarketingFeature(
|
47 |
feature_id=2,
|
|
|
49 |
category="seo",
|
50 |
description="Identifies potential SEO keywords",
|
51 |
interpretation_guide="High activation suggests strong SEO potential",
|
52 |
+
layer=20,
|
53 |
),
|
54 |
]
|
55 |
|
56 |
|
57 |
+
class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder output is gated by a per-feature threshold.

    Parameters (all zero-initialised here and expected to be overwritten,
    e.g. via ``load_state_dict``):
        W_enc:     (d_model, d_sae) encoder weight.
        W_dec:     (d_sae, d_model) decoder weight.
        threshold: (d_sae,) gate applied to encoder pre-activations.
        b_enc:     (d_sae,) encoder bias.
        b_dec:     (d_model,) decoder bias.
    """

    def __init__(self, d_model, d_sae):
        """Allocate zero-filled parameters for the given model/SAE widths."""
        super().__init__()
        # Registered in this exact order so parameter/state-dict iteration
        # order stays W_enc, W_dec, threshold, b_enc, b_dec.
        layout = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, dims in layout:
            setattr(self, param_name, nn.Parameter(torch.zeros(*dims)))

    def encode(self, input_acts):
        """Project ``input_acts`` into SAE space and apply the threshold gate.

        Entries whose pre-activation does not exceed ``threshold`` (or is
        not positive) come back as zero.
        """
        projected = torch.matmul(input_acts, self.W_enc) + self.b_enc
        above_threshold = projected > self.threshold
        return above_threshold * torch.nn.functional.relu(projected)

    def decode(self, acts):
        """Map SAE activations back through ``W_dec`` and add ``b_dec``."""
        return torch.matmul(acts, self.W_dec) + self.b_dec

    def forward(self, acts):
        """Return the reconstruction ``decode(encode(acts))``."""
        return self.decode(self.encode(acts))
|
79 |
+
|
80 |
+
|
81 |
+
class MarketingAnalyzer:
|
82 |
def __init__(self):
    """Pick a device, switch autograd off, then load the model and the SAEs."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(backend)
    # Inference only: disabling gradient tracking avoids memory issues.
    # NOTE: this call changes the grad mode globally, not just for this object.
    torch.set_grad_enabled(False)
    self._initialize_model()
    self._load_saes()
|
87 |
|
88 |
def _initialize_model(self):
    """Load the ``google/gemma-2-2b`` model and tokenizer onto this instance.

    Sets ``self.model`` (put into eval mode) and ``self.tokenizer``.

    Raises:
        Exception: Whatever the loading calls raise; the error is logged
            and then re-raised unchanged.
    """
    checkpoint = "google/gemma-2-2b"
    try:
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model.eval()
        logger.info("Model initialized successfully")
    except Exception as e:
        reason = str(e)
        logger.error(f"Error initializing model: {reason}")
        raise
|
99 |
|
100 |
def _load_saes(self):
|
|
|
101 |
self.saes = {}
|
102 |
for feature in MARKETING_FEATURES:
|
103 |
try:
|
|
|
104 |
path = hf_hub_download(
|
105 |
+
repo_id="google/gemma-scope-2b-pt-res",
|
106 |
filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
|
107 |
+
force_download=False,
|
108 |
)
|
109 |
params = np.load(path)
|
110 |
+
|
111 |
+
# Create SAE
|
112 |
+
d_model = params["W_enc"].shape[0]
|
113 |
+
d_sae = params["W_enc"].shape[1]
|
114 |
+
sae = JumpReLUSAE(d_model, d_sae).to(self.device)
|
115 |
+
|
116 |
+
# Load parameters
|
117 |
+
sae_params = {
|
118 |
+
k: torch.from_numpy(v).to(self.device) for k, v in params.items()
|
119 |
}
|
120 |
+
sae.load_state_dict(sae_params)
|
121 |
+
|
122 |
+
self.saes[feature.feature_id] = {"sae": sae, "feature": feature}
|
123 |
logger.info(f"Loaded SAE for feature {feature.feature_id}")
|
124 |
except Exception as e:
|
125 |
logger.error(
|
|
|
127 |
)
|
128 |
continue
|
129 |
|
130 |
+
def _gather_activations(self, text: str, layer: int):
|
131 |
+
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
|
132 |
+
target_act = None
|
133 |
+
|
134 |
+
def hook(mod, inputs, outputs):
|
135 |
+
nonlocal target_act
|
136 |
+
target_act = outputs[0]
|
137 |
+
return outputs
|
138 |
+
|
139 |
+
handle = self.model.model.layers[layer].register_forward_hook(hook)
|
140 |
+
with torch.no_grad():
|
141 |
+
_ = self.model(**inputs)
|
142 |
+
handle.remove()
|
143 |
+
|
144 |
+
return target_act, inputs
|
145 |
+
|
146 |
def analyze_content(self, text: str) -> Dict:
|
|
|
147 |
results = {
|
148 |
"text": text,
|
149 |
"features": {},
|
|
|
152 |
}
|
153 |
|
154 |
try:
|
155 |
+
# Get activations for each feature
|
|
|
|
|
|
|
|
|
|
|
156 |
for feature_id, sae_data in self.saes.items():
|
157 |
feature = sae_data["feature"]
|
158 |
+
sae = sae_data["sae"]
|
159 |
|
160 |
+
# Get layer activations
|
161 |
+
activations, inputs = self._gather_activations(text, feature.layer)
|
162 |
+
|
163 |
+
# Skip BOS token and get activations
|
164 |
+
sae_acts = sae.encode(activations.to(torch.float32))
|
165 |
+
sae_acts = sae_acts[:, 1:] # Skip BOS token
|
166 |
|
167 |
+
# Calculate metrics
|
168 |
+
if sae_acts.numel() > 0:
|
169 |
+
mean_activation = float(sae_acts.mean())
|
170 |
+
max_activation = float(sae_acts.max())
|
|
|
171 |
else:
|
172 |
mean_activation = 0.0
|
173 |
max_activation = 0.0
|
|
|
185 |
|
186 |
results["features"][feature_id] = feature_result
|
187 |
|
|
|
188 |
if feature.category not in results["categories"]:
|
189 |
results["categories"][feature.category] = []
|
190 |
results["categories"][feature.category].append(feature_result)
|
191 |
|
|
|
192 |
results["recommendations"] = self._generate_recommendations(results)
|
193 |
|
194 |
except Exception as e:
|
|
|
197 |
|
198 |
return results
|
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
def _interpret_activation(
|
201 |
self, activation: float, feature: MarketingFeature
|
202 |
) -> str:
|
|
|
203 |
if activation > 0.8:
|
204 |
return f"Very strong presence of {feature.name.lower()}"
|
205 |
elif activation > 0.5:
|
|
|
208 |
return f"Limited presence of {feature.name.lower()}"
|
209 |
|
210 |
def _generate_recommendations(self, results: Dict) -> List[str]:
|
|
|
211 |
recommendations = []
|
212 |
|
213 |
try:
|
|
|
214 |
tech_features = [
|
215 |
f for f in results["features"].values() if f["category"] == "technical"
|
216 |
]
|
217 |
|
|
|
218 |
if tech_features:
|
219 |
tech_score = np.mean([f["activation_score"] for f in tech_features])
|
|
|
220 |
if tech_score > 0.8:
|
221 |
recommendations.append(
|
222 |
"Consider simplifying technical language for broader audience"
|
|
|
232 |
|
233 |
|
234 |
def create_gradio_interface():
|
|
|
235 |
try:
|
236 |
analyzer = MarketingAnalyzer()
|
237 |
except Exception as e:
|
|
|
241 |
inputs=gr.Textbox(),
|
242 |
outputs=gr.Textbox(),
|
243 |
title="Marketing Content Analyzer (Error)",
|
244 |
+
description="Failed to initialize.",
|
245 |
)
|
246 |
|
247 |
def analyze(text):
|
248 |
results = analyzer.analyze_content(text)
|
249 |
|
|
|
250 |
output = "Content Analysis Results\n\n"
|
251 |
|
|
|
252 |
output += "Category Scores:\n"
|
253 |
for category, features in results["categories"].items():
|
254 |
+
if features:
|
255 |
avg_score = np.mean([f["activation_score"] for f in features])
|
256 |
output += f"{category.title()}: {avg_score:.2f}\n"
|
257 |
|
|
|
258 |
output += "\nFeature Details:\n"
|
259 |
for feature_id, feature in results["features"].items():
|
260 |
output += f"\n{feature['name']}:\n"
|
261 |
output += f"Score: {feature['activation_score']:.2f}\n"
|
262 |
output += f"Interpretation: {feature['interpretation']}\n"
|
263 |
|
|
|
264 |
if results["recommendations"]:
|
265 |
output += "\nRecommendations:\n"
|
266 |
for rec in results["recommendations"]:
|
|
|
268 |
|
269 |
return output
|
270 |
|
271 |
+
with gr.Blocks(
|
272 |
+
theme=gr.themes.Default(
|
273 |
+
font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
|
274 |
+
primary_hue="indigo",
|
275 |
+
secondary_hue="blue",
|
276 |
+
neutral_hue="gray",
|
277 |
+
)
|
278 |
+
) as interface:
|
279 |
+
gr.Markdown("# Marketing Content Analyzer")
|
280 |
+
gr.Markdown(
|
281 |
+
"Analyze your marketing content using Gemma Scope's neural features"
|
282 |
+
)
|
283 |
+
|
284 |
+
with gr.Row():
|
285 |
+
input_text = gr.Textbox(
|
286 |
+
lines=5,
|
287 |
+
placeholder="Enter your marketing content here...",
|
288 |
+
label="Marketing Content",
|
289 |
+
)
|
290 |
+
output_text = gr.Textbox(label="Analysis Results")
|
291 |
+
|
292 |
+
analyze_btn = gr.Button("Analyze", variant="primary")
|
293 |
+
analyze_btn.click(fn=analyze, inputs=input_text, outputs=output_text)
|
294 |
+
|
295 |
+
gr.Examples(
|
296 |
+
examples=[
|
297 |
+
"WordLift is an AI-powered SEO tool",
|
298 |
+
"Our advanced machine learning algorithms optimize your content",
|
299 |
+
"Simple and effective website optimization",
|
300 |
+
],
|
301 |
+
inputs=input_text,
|
302 |
+
)
|
303 |
|
304 |
return interface
|
305 |
|