cyberandy committed on
Commit
7e6371a
·
1 Parent(s): 01d3df7
Files changed (1) hide show
  1. app.py +109 -92
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import hf_hub_download
5
  import numpy as np
@@ -7,7 +8,6 @@ from dataclasses import dataclass
7
  from typing import List, Dict, Optional
8
  import logging
9
 
10
- # Initialize logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
@@ -25,7 +25,7 @@ class MarketingFeature:
25
  threshold: float = 0.1
26
 
27
 
28
- # Define marketing-relevant features from Gemma Scope
29
  MARKETING_FEATURES = [
30
  MarketingFeature(
31
  feature_id=35,
@@ -33,7 +33,7 @@ MARKETING_FEATURES = [
33
  category="technical",
34
  description="Detects technical and specialized terminology",
35
  interpretation_guide="High activation indicates strong technical focus",
36
- layer=6, # Adjusted for Gemma-2B structure
37
  ),
38
  MarketingFeature(
39
  feature_id=6680,
@@ -41,7 +41,7 @@ MARKETING_FEATURES = [
41
  category="technical",
42
  description="Identifies complex technical concepts",
43
  interpretation_guide="Consider simplifying language if activation is too high",
44
- layer=6, # Adjusted for Gemma-2B structure
45
  ),
46
  MarketingFeature(
47
  feature_id=2,
@@ -49,57 +49,77 @@ MARKETING_FEATURES = [
49
  category="seo",
50
  description="Identifies potential SEO keywords",
51
  interpretation_guide="High activation suggests strong SEO potential",
52
- layer=6, # Adjusted for Gemma-2B structure
53
  ),
54
  ]
55
 
56
 
57
- class MarketingAnalyzer:
58
- """Main class for analyzing marketing content using Gemma Scope"""
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
 
 
 
 
 
 
 
 
 
60
  def __init__(self):
61
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
- # Store model size as instance variable
63
- self.model_size = "2b"
64
  self._initialize_model()
65
  self._load_saes()
66
 
67
  def _initialize_model(self):
68
- """Initialize Gemma model and tokenizer"""
69
  try:
70
- model_name = f"google/gemma-{self.model_size}"
71
-
72
- # Initialize model and tokenizer with token from environment
73
  self.model = AutoModelForCausalLM.from_pretrained(
74
- model_name, device_map="auto"
75
  )
76
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
77
-
78
  self.model.eval()
79
- logger.info(f"Initialized model: {model_name}")
80
-
81
  except Exception as e:
82
  logger.error(f"Error initializing model: {str(e)}")
83
  raise
84
 
85
  def _load_saes(self):
86
- """Load relevant SAEs from Gemma Scope"""
87
  self.saes = {}
88
  for feature in MARKETING_FEATURES:
89
  try:
90
- # Load SAE parameters for each feature
91
  path = hf_hub_download(
92
- repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
93
  filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
 
94
  )
95
  params = np.load(path)
96
- self.saes[feature.feature_id] = {
97
- "params": {
98
- k: torch.from_numpy(v).to(self.device)
99
- for k, v in params.items()
100
- },
101
- "feature": feature,
 
 
 
102
  }
 
 
 
103
  logger.info(f"Loaded SAE for feature {feature.feature_id}")
104
  except Exception as e:
105
  logger.error(
@@ -107,8 +127,23 @@ class MarketingAnalyzer:
107
  )
108
  continue
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def analyze_content(self, text: str) -> Dict:
111
- """Analyze marketing content using loaded SAEs"""
112
  results = {
113
  "text": text,
114
  "features": {},
@@ -117,26 +152,22 @@ class MarketingAnalyzer:
117
  }
118
 
119
  try:
120
- # Get model activations
121
- inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
122
- with torch.no_grad():
123
- outputs = self.model(**inputs, output_hidden_states=True)
124
-
125
- # Analyze each feature
126
  for feature_id, sae_data in self.saes.items():
127
  feature = sae_data["feature"]
128
- layer_output = outputs.hidden_states[feature.layer]
129
 
130
- # Apply SAE
131
- activations = self._apply_sae(
132
- layer_output, sae_data["params"], feature.threshold
133
- )
 
 
134
 
135
- # Skip BOS token and handle empty activations
136
- activations = activations[:, 1:] # Skip BOS token
137
- if activations.numel() > 0:
138
- mean_activation = float(activations.mean())
139
- max_activation = float(activations.max())
140
  else:
141
  mean_activation = 0.0
142
  max_activation = 0.0
@@ -154,12 +185,10 @@ class MarketingAnalyzer:
154
 
155
  results["features"][feature_id] = feature_result
156
 
157
- # Aggregate by category
158
  if feature.category not in results["categories"]:
159
  results["categories"][feature.category] = []
160
  results["categories"][feature.category].append(feature_result)
161
 
162
- # Generate recommendations
163
  results["recommendations"] = self._generate_recommendations(results)
164
 
165
  except Exception as e:
@@ -168,22 +197,9 @@ class MarketingAnalyzer:
168
 
169
  return results
170
 
171
- def _apply_sae(
172
- self,
173
- activations: torch.Tensor,
174
- sae_params: Dict[str, torch.Tensor],
175
- threshold: float,
176
- ) -> torch.Tensor:
177
- """Apply SAE to get feature activations"""
178
- pre_acts = activations @ sae_params["W_enc"] + sae_params["b_enc"]
179
- mask = pre_acts > sae_params["threshold"]
180
- acts = mask * torch.nn.functional.relu(pre_acts)
181
- return acts
182
-
183
  def _interpret_activation(
184
  self, activation: float, feature: MarketingFeature
185
  ) -> str:
186
- """Interpret activation patterns for a feature"""
187
  if activation > 0.8:
188
  return f"Very strong presence of {feature.name.lower()}"
189
  elif activation > 0.5:
@@ -192,19 +208,15 @@ class MarketingAnalyzer:
192
  return f"Limited presence of {feature.name.lower()}"
193
 
194
  def _generate_recommendations(self, results: Dict) -> List[str]:
195
- """Generate content recommendations based on analysis"""
196
  recommendations = []
197
 
198
  try:
199
- # Get technical features
200
  tech_features = [
201
  f for f in results["features"].values() if f["category"] == "technical"
202
  ]
203
 
204
- # Calculate average technical score if we have features
205
  if tech_features:
206
  tech_score = np.mean([f["activation_score"] for f in tech_features])
207
-
208
  if tech_score > 0.8:
209
  recommendations.append(
210
  "Consider simplifying technical language for broader audience"
@@ -220,7 +232,6 @@ class MarketingAnalyzer:
220
 
221
 
222
  def create_gradio_interface():
223
- """Create Gradio interface for marketing analysis"""
224
  try:
225
  analyzer = MarketingAnalyzer()
226
  except Exception as e:
@@ -230,30 +241,26 @@ def create_gradio_interface():
230
  inputs=gr.Textbox(),
231
  outputs=gr.Textbox(),
232
  title="Marketing Content Analyzer (Error)",
233
- description="Failed to initialize. Please check if HF_TOKEN is properly set.",
234
  )
235
 
236
  def analyze(text):
237
  results = analyzer.analyze_content(text)
238
 
239
- # Format results for display
240
  output = "Content Analysis Results\n\n"
241
 
242
- # Overall category scores
243
  output += "Category Scores:\n"
244
  for category, features in results["categories"].items():
245
- if features: # Check if we have features for this category
246
  avg_score = np.mean([f["activation_score"] for f in features])
247
  output += f"{category.title()}: {avg_score:.2f}\n"
248
 
249
- # Feature details
250
  output += "\nFeature Details:\n"
251
  for feature_id, feature in results["features"].items():
252
  output += f"\n{feature['name']}:\n"
253
  output += f"Score: {feature['activation_score']:.2f}\n"
254
  output += f"Interpretation: {feature['interpretation']}\n"
255
 
256
- # Recommendations
257
  if results["recommendations"]:
258
  output += "\nRecommendations:\n"
259
  for rec in results["recommendations"]:
@@ -261,28 +268,38 @@ def create_gradio_interface():
261
 
262
  return output
263
 
264
- # Create interface with custom theming
265
- custom_theme = gr.themes.Soft(
266
- primary_hue="indigo", secondary_hue="blue", neutral_hue="gray"
267
- )
268
-
269
- interface = gr.Interface(
270
- fn=analyze,
271
- inputs=gr.Textbox(
272
- lines=5,
273
- placeholder="Enter your marketing content here...",
274
- label="Marketing Content",
275
- ),
276
- outputs=gr.Textbox(label="Analysis Results"),
277
- title="Marketing Content Analyzer",
278
- description="Analyze your marketing content using Gemma Scope's neural features",
279
- examples=[
280
- ["WordLift is an AI-powered SEO tool"],
281
- ["Our advanced machine learning algorithms optimize your content"],
282
- ["Simple and effective website optimization"],
283
- ],
284
- theme=custom_theme,
285
- )
 
 
 
 
 
 
 
 
 
 
286
 
287
  return interface
288
 
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn as nn
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from huggingface_hub import hf_hub_download
6
  import numpy as np
 
8
  from typing import List, Dict, Optional
9
  import logging
10
 
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
 
25
  threshold: float = 0.1
26
 
27
 
28
+ # Define relevant features
29
  MARKETING_FEATURES = [
30
  MarketingFeature(
31
  feature_id=35,
 
33
  category="technical",
34
  description="Detects technical and specialized terminology",
35
  interpretation_guide="High activation indicates strong technical focus",
36
+ layer=20,
37
  ),
38
  MarketingFeature(
39
  feature_id=6680,
 
41
  category="technical",
42
  description="Identifies complex technical concepts",
43
  interpretation_guide="Consider simplifying language if activation is too high",
44
+ layer=20,
45
  ),
46
  MarketingFeature(
47
  feature_id=2,
 
49
  category="seo",
50
  description="Identifies potential SEO keywords",
51
  interpretation_guide="High activation suggests strong SEO potential",
52
+ layer=20,
53
  ),
54
  ]
55
 
56
 
57
class JumpReLUSAE(nn.Module):
    """Sparse autoencoder with a JumpReLU activation.

    The module owns five parameters, all created zero-filled; real weights
    are expected to be loaded afterwards (for example with
    ``load_state_dict``):

    * ``W_enc``     -- ``(d_model, d_sae)`` encoder weights
    * ``W_dec``     -- ``(d_sae, d_model)`` decoder weights
    * ``threshold`` -- ``(d_sae,)`` per-feature activation threshold
    * ``b_enc``     -- ``(d_sae,)`` encoder bias
    * ``b_dec``     -- ``(d_model,)`` decoder bias
    """

    def __init__(self, d_model, d_sae):
        """Create zero-initialised parameters for the given widths.

        Args:
            d_model: Size of the last dimension of the encoder input.
            d_sae: Number of sparse features produced by ``encode``.
        """
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Map activations to sparse feature activations.

        A feature is kept only where its pre-activation is strictly greater
        than that feature's ``threshold``; everything else is zeroed.

        Args:
            input_acts: Tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with last dimension ``d_sae``, in the dtype of ``W_enc``.
        """
        # Cast to the parameter dtype so an input in a different precision
        # does not fail in the matmul. No-op when the dtypes already match.
        input_acts = input_acts.to(self.W_enc.dtype)
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = pre_acts > self.threshold
        return mask * torch.nn.functional.relu(pre_acts)

    def decode(self, acts):
        """Project feature activations back through ``W_dec`` and ``b_dec``.

        Args:
            acts: Tensor whose last dimension is ``d_sae``.

        Returns:
            Tensor with last dimension ``d_model``, in the dtype of ``W_dec``.
        """
        # Same dtype alignment as in ``encode``.
        acts = acts.to(self.W_dec.dtype)
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode ``acts`` and return the decoded reconstruction."""
        return self.decode(self.encode(acts))
79
+
80
+
81
+ class MarketingAnalyzer:
82
  def __init__(self):
83
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ torch.set_grad_enabled(False) # Avoid memory issues
 
85
  self._initialize_model()
86
  self._load_saes()
87
 
88
  def _initialize_model(self):
 
89
  try:
 
 
 
90
  self.model = AutoModelForCausalLM.from_pretrained(
91
+ "google/gemma-2-2b", device_map="auto"
92
  )
93
+ self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
 
94
  self.model.eval()
95
+ logger.info("Model initialized successfully")
 
96
  except Exception as e:
97
  logger.error(f"Error initializing model: {str(e)}")
98
  raise
99
 
100
  def _load_saes(self):
 
101
  self.saes = {}
102
  for feature in MARKETING_FEATURES:
103
  try:
 
104
  path = hf_hub_download(
105
+ repo_id="google/gemma-scope-2b-pt-res",
106
  filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
107
+ force_download=False,
108
  )
109
  params = np.load(path)
110
+
111
+ # Create SAE
112
+ d_model = params["W_enc"].shape[0]
113
+ d_sae = params["W_enc"].shape[1]
114
+ sae = JumpReLUSAE(d_model, d_sae).to(self.device)
115
+
116
+ # Load parameters
117
+ sae_params = {
118
+ k: torch.from_numpy(v).to(self.device) for k, v in params.items()
119
  }
120
+ sae.load_state_dict(sae_params)
121
+
122
+ self.saes[feature.feature_id] = {"sae": sae, "feature": feature}
123
  logger.info(f"Loaded SAE for feature {feature.feature_id}")
124
  except Exception as e:
125
  logger.error(
 
127
  )
128
  continue
129
 
130
def _gather_activations(self, text: str, layer: int):
    """Run ``text`` through the model and capture one layer's output.

    A temporary forward hook is attached to ``self.model.model.layers[layer]``
    for the duration of a single forward pass under ``torch.no_grad()``.

    Args:
        text: Input string passed to ``self.tokenizer``.
        layer: Index into ``self.model.model.layers`` of the layer to tap.

    Returns:
        A ``(target_act, inputs)`` tuple. ``target_act`` is the first element
        of the hooked layer's output (or the output itself when the layer
        returns a bare tensor), or ``None`` if the hook never fired.
        ``inputs`` is the tokenizer output after ``.to(self.device)``.

    Raises:
        Exception: Whatever the tokenizer or model raises. The hook is
            removed before the exception propagates.
    """
    inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
    captured = []

    def hook(module, layer_inputs, layer_outputs):
        # Layers may return a tuple (first element taken, as before) or a
        # bare tensor; accept both. Returning None leaves the output
        # unmodified.
        if isinstance(layer_outputs, (tuple, list)):
            captured.append(layer_outputs[0])
        else:
            captured.append(layer_outputs)

    handle = self.model.model.layers[layer].register_forward_hook(hook)
    try:
        with torch.no_grad():
            self.model(**inputs)
    finally:
        # Always detach the hook; otherwise a failed forward pass would leave
        # it registered on the layer for every later call.
        handle.remove()

    # Keep the most recent capture, matching "last call wins" semantics.
    target_act = captured[-1] if captured else None
    return target_act, inputs
145
+
146
  def analyze_content(self, text: str) -> Dict:
 
147
  results = {
148
  "text": text,
149
  "features": {},
 
152
  }
153
 
154
  try:
155
+ # Get activations for each feature
 
 
 
 
 
156
  for feature_id, sae_data in self.saes.items():
157
  feature = sae_data["feature"]
158
+ sae = sae_data["sae"]
159
 
160
+ # Get layer activations
161
+ activations, inputs = self._gather_activations(text, feature.layer)
162
+
163
+ # Skip BOS token and get activations
164
+ sae_acts = sae.encode(activations.to(torch.float32))
165
+ sae_acts = sae_acts[:, 1:] # Skip BOS token
166
 
167
+ # Calculate metrics
168
+ if sae_acts.numel() > 0:
169
+ mean_activation = float(sae_acts.mean())
170
+ max_activation = float(sae_acts.max())
 
171
  else:
172
  mean_activation = 0.0
173
  max_activation = 0.0
 
185
 
186
  results["features"][feature_id] = feature_result
187
 
 
188
  if feature.category not in results["categories"]:
189
  results["categories"][feature.category] = []
190
  results["categories"][feature.category].append(feature_result)
191
 
 
192
  results["recommendations"] = self._generate_recommendations(results)
193
 
194
  except Exception as e:
 
197
 
198
  return results
199
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def _interpret_activation(
201
  self, activation: float, feature: MarketingFeature
202
  ) -> str:
 
203
  if activation > 0.8:
204
  return f"Very strong presence of {feature.name.lower()}"
205
  elif activation > 0.5:
 
208
  return f"Limited presence of {feature.name.lower()}"
209
 
210
  def _generate_recommendations(self, results: Dict) -> List[str]:
 
211
  recommendations = []
212
 
213
  try:
 
214
  tech_features = [
215
  f for f in results["features"].values() if f["category"] == "technical"
216
  ]
217
 
 
218
  if tech_features:
219
  tech_score = np.mean([f["activation_score"] for f in tech_features])
 
220
  if tech_score > 0.8:
221
  recommendations.append(
222
  "Consider simplifying technical language for broader audience"
 
232
 
233
 
234
  def create_gradio_interface():
 
235
  try:
236
  analyzer = MarketingAnalyzer()
237
  except Exception as e:
 
241
  inputs=gr.Textbox(),
242
  outputs=gr.Textbox(),
243
  title="Marketing Content Analyzer (Error)",
244
+ description="Failed to initialize.",
245
  )
246
 
247
  def analyze(text):
248
  results = analyzer.analyze_content(text)
249
 
 
250
  output = "Content Analysis Results\n\n"
251
 
 
252
  output += "Category Scores:\n"
253
  for category, features in results["categories"].items():
254
+ if features:
255
  avg_score = np.mean([f["activation_score"] for f in features])
256
  output += f"{category.title()}: {avg_score:.2f}\n"
257
 
 
258
  output += "\nFeature Details:\n"
259
  for feature_id, feature in results["features"].items():
260
  output += f"\n{feature['name']}:\n"
261
  output += f"Score: {feature['activation_score']:.2f}\n"
262
  output += f"Interpretation: {feature['interpretation']}\n"
263
 
 
264
  if results["recommendations"]:
265
  output += "\nRecommendations:\n"
266
  for rec in results["recommendations"]:
 
268
 
269
  return output
270
 
271
+ with gr.Blocks(
272
+ theme=gr.themes.Default(
273
+ font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
274
+ primary_hue="indigo",
275
+ secondary_hue="blue",
276
+ neutral_hue="gray",
277
+ )
278
+ ) as interface:
279
+ gr.Markdown("# Marketing Content Analyzer")
280
+ gr.Markdown(
281
+ "Analyze your marketing content using Gemma Scope's neural features"
282
+ )
283
+
284
+ with gr.Row():
285
+ input_text = gr.Textbox(
286
+ lines=5,
287
+ placeholder="Enter your marketing content here...",
288
+ label="Marketing Content",
289
+ )
290
+ output_text = gr.Textbox(label="Analysis Results")
291
+
292
+ analyze_btn = gr.Button("Analyze", variant="primary")
293
+ analyze_btn.click(fn=analyze, inputs=input_text, outputs=output_text)
294
+
295
+ gr.Examples(
296
+ examples=[
297
+ "WordLift is an AI-powered SEO tool",
298
+ "Our advanced machine learning algorithms optimize your content",
299
+ "Simple and effective website optimization",
300
+ ],
301
+ inputs=input_text,
302
+ )
303
 
304
  return interface
305