cyberandy commited on
Commit
e7c964f
·
1 Parent(s): e78ab36
Files changed (1) hide show
  1. app.py +104 -87
app.py CHANGED
@@ -83,7 +83,6 @@ class MarketingAnalyzer:
83
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
  torch.set_grad_enabled(False) # Avoid memory issues
85
  self._initialize_model()
86
- self._load_saes()
87
 
88
  def _initialize_model(self):
89
  try:
@@ -97,35 +96,31 @@ class MarketingAnalyzer:
97
  logger.error(f"Error initializing model: {str(e)}")
98
  raise
99
 
100
- def _load_saes(self):
101
- self.saes = {}
102
- for feature in MARKETING_FEATURES:
103
- try:
104
- path = hf_hub_download(
105
- repo_id="google/gemma-scope-2b-pt-res",
106
- filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
107
- force_download=False,
108
- )
109
- params = np.load(path)
110
 
111
- # Create SAE
112
- d_model = params["W_enc"].shape[0]
113
- d_sae = params["W_enc"].shape[1]
114
- sae = JumpReLUSAE(d_model, d_sae).to(self.device)
115
 
116
- # Load parameters
117
- sae_params = {
118
- k: torch.from_numpy(v).to(self.device) for k, v in params.items()
119
- }
120
- sae.load_state_dict(sae_params)
121
 
122
- self.saes[feature.feature_id] = {"sae": sae, "feature": feature}
123
- logger.info(f"Loaded SAE for feature {feature.feature_id}")
124
- except Exception as e:
125
- logger.error(
126
- f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
127
- )
128
- continue
129
 
130
  def _gather_activations(self, text: str, layer: int):
131
  inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
@@ -143,7 +138,23 @@ class MarketingAnalyzer:
143
 
144
  return target_act, inputs
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def analyze_content(self, text: str) -> Dict:
 
147
  results = {
148
  "text": text,
149
  "features": {},
@@ -152,44 +163,74 @@ class MarketingAnalyzer:
152
  }
153
 
154
  try:
155
- # Get activations for each feature
156
- for feature_id, sae_data in self.saes.items():
157
- feature = sae_data["feature"]
158
- sae = sae_data["sae"]
159
-
160
- # Get layer activations
161
- activations, inputs = self._gather_activations(text, feature.layer)
162
-
163
- # Skip BOS token and get activations
164
- sae_acts = sae.encode(activations.to(torch.float32))
165
- sae_acts = sae_acts[:, 1:] # Skip BOS token
166
-
167
- # Calculate metrics
168
- if sae_acts.numel() > 0:
169
- mean_activation = float(sae_acts.mean())
170
- max_activation = float(sae_acts.max())
171
- else:
172
- mean_activation = 0.0
173
- max_activation = 0.0
174
-
175
- # Record results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  feature_result = {
177
- "name": feature.name,
178
- "category": feature.category,
179
- "activation_score": mean_activation,
180
- "max_activation": max_activation,
181
  "interpretation": self._interpret_activation(
182
- mean_activation, feature
183
  ),
184
  }
185
 
186
  results["features"][feature_id] = feature_result
187
 
188
- if feature.category not in results["categories"]:
189
- results["categories"][feature.category] = []
190
- results["categories"][feature.category].append(feature_result)
191
-
192
- results["recommendations"] = self._generate_recommendations(results)
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  except Exception as e:
195
  logger.error(f"Error analyzing content: {str(e)}")
@@ -197,38 +238,14 @@ class MarketingAnalyzer:
197
 
198
  return results
199
 
200
- def _interpret_activation(
201
- self, activation: float, feature: MarketingFeature
202
- ) -> str:
203
  if activation > 0.8:
204
- return f"Very strong presence of {feature.name.lower()}"
205
  elif activation > 0.5:
206
- return f"Moderate presence of {feature.name.lower()}"
207
  else:
208
- return f"Limited presence of {feature.name.lower()}"
209
-
210
- def _generate_recommendations(self, results: Dict) -> List[str]:
211
- recommendations = []
212
-
213
- try:
214
- tech_features = [
215
- f for f in results["features"].values() if f["category"] == "technical"
216
- ]
217
-
218
- if tech_features:
219
- tech_score = np.mean([f["activation_score"] for f in tech_features])
220
- if tech_score > 0.8:
221
- recommendations.append(
222
- "Consider simplifying technical language for broader audience"
223
- )
224
- elif tech_score < 0.3:
225
- recommendations.append(
226
- "Could benefit from more specific technical details"
227
- )
228
- except Exception as e:
229
- logger.error(f"Error generating recommendations: {str(e)}")
230
-
231
- return recommendations
232
 
233
 
234
  def create_gradio_interface():
 
83
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
  torch.set_grad_enabled(False) # Avoid memory issues
85
  self._initialize_model()
 
86
 
87
  def _initialize_model(self):
88
  try:
 
96
  logger.error(f"Error initializing model: {str(e)}")
97
  raise
98
 
99
+ def _load_sae(self, feature_id: int, layer: int = 20):
100
+ """Dynamically load a single SAE"""
101
+ try:
102
+ path = hf_hub_download(
103
+ repo_id="google/gemma-scope-2b-pt-res",
104
+ filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
105
+ force_download=False,
106
+ )
107
+ params = np.load(path)
 
108
 
109
+ # Create SAE
110
+ d_model = params["W_enc"].shape[0]
111
+ d_sae = params["W_enc"].shape[1]
112
+ sae = JumpReLUSAE(d_model, d_sae).to(self.device)
113
 
114
+ # Load parameters
115
+ sae_params = {
116
+ k: torch.from_numpy(v).to(self.device) for k, v in params.items()
117
+ }
118
+ sae.load_state_dict(sae_params)
119
 
120
+ return sae
121
+ except Exception as e:
122
+ logger.error(f"Error loading SAE for feature {feature_id}: {str(e)}")
123
+ return None
 
 
 
124
 
125
  def _gather_activations(self, text: str, layer: int):
126
  inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
 
138
 
139
  return target_act, inputs
140
 
141
+ def _get_feature_activations(self, text: str, sae, layer: int = 20):
142
+ """Get activations for a single feature"""
143
+ activations, _ = self._gather_activations(text, layer)
144
+ sae_acts = sae.encode(activations.to(torch.float32))
145
+ sae_acts = sae_acts[:, 1:] # Skip BOS token
146
+
147
+ if sae_acts.numel() > 0:
148
+ mean_activation = float(sae_acts.mean())
149
+ max_activation = float(sae_acts.max())
150
+ else:
151
+ mean_activation = 0.0
152
+ max_activation = 0.0
153
+
154
+ return mean_activation, max_activation
155
+
156
  def analyze_content(self, text: str) -> Dict:
157
+ """Analyze content and find most relevant features"""
158
  results = {
159
  "text": text,
160
  "features": {},
 
163
  }
164
 
165
  try:
166
+ # Start with a set of potential features to explore
167
+ feature_pool = list(range(1, 16385)) # Full range of features
168
+ sample_size = 50 # Number of features to sample
169
+ sampled_features = np.random.choice(
170
+ feature_pool, sample_size, replace=False
171
+ )
172
+
173
+ # Test each feature
174
+ feature_activations = []
175
+ for feature_id in sampled_features:
176
+ sae = self._load_sae(feature_id)
177
+ if sae is None:
178
+ continue
179
+
180
+ mean_activation, max_activation = self._get_feature_activations(
181
+ text, sae
182
+ )
183
+ feature_activations.append(
184
+ {
185
+ "feature_id": feature_id,
186
+ "mean_activation": mean_activation,
187
+ "max_activation": max_activation,
188
+ }
189
+ )
190
+
191
+ # Sort by activation and take top features
192
+ top_features = sorted(
193
+ feature_activations, key=lambda x: x["max_activation"], reverse=True
194
+ )[
195
+ :3
196
+ ] # Keep top 3 features
197
+
198
+ # Analyze top features in detail
199
+ for feature_data in top_features:
200
+ feature_id = feature_data["feature_id"]
201
+
202
+ # Get neuronpedia data if available (this would be a placeholder)
203
+ feature_name = f"Feature {feature_id}"
204
+ feature_category = "neural" # Default category
205
+
206
  feature_result = {
207
+ "name": feature_name,
208
+ "category": feature_category,
209
+ "activation_score": feature_data["mean_activation"],
210
+ "max_activation": feature_data["max_activation"],
211
  "interpretation": self._interpret_activation(
212
+ feature_data["mean_activation"], feature_id
213
  ),
214
  }
215
 
216
  results["features"][feature_id] = feature_result
217
 
218
+ if feature_category not in results["categories"]:
219
+ results["categories"][feature_category] = []
220
+ results["categories"][feature_category].append(feature_result)
221
+
222
+ # Generate recommendations based on activations
223
+ if top_features:
224
+ max_activation = max(f["max_activation"] for f in top_features)
225
+ if max_activation > 0.8:
226
+ results["recommendations"].append(
227
+ f"Strong activation detected in feature {top_features[0]['feature_id']}. "
228
+ "Consider exploring this aspect further."
229
+ )
230
+ elif max_activation < 0.3:
231
+ results["recommendations"].append(
232
+ "Low feature activations overall. Content might benefit from more distinctive elements."
233
+ )
234
 
235
  except Exception as e:
236
  logger.error(f"Error analyzing content: {str(e)}")
 
238
 
239
  return results
240
 
241
+ def _interpret_activation(self, activation: float, feature_id: int) -> str:
242
+ """Interpret activation levels for a feature"""
 
243
  if activation > 0.8:
244
+ return f"Very strong activation of feature {feature_id}"
245
  elif activation > 0.5:
246
+ return f"Moderate activation of feature {feature_id}"
247
  else:
248
+ return f"Limited activation of feature {feature_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
  def create_gradio_interface():