cyberandy commited on
Commit
9186441
·
verified ·
1 Parent(s): deaf693

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -50
app.py CHANGED
@@ -48,40 +48,31 @@ MARKETING_FEATURES = [
48
  interpretation_guide="High activation suggests strong SEO potential",
49
  layer=20
50
  ),
51
- # Add more relevant features as we discover them
52
  ]
53
 
54
  class MarketingAnalyzer:
55
  """Main class for analyzing marketing content using Gemma Scope"""
56
 
57
- def __init__(self, model_size: str = "2b"):
58
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
- self._initialize_model(model_size)
 
 
60
  self._load_saes()
61
 
62
- def _initialize_model(self, model_size: str):
63
  """Initialize Gemma model and tokenizer"""
64
  try:
65
- import os
66
- model_name = f"google/gemma-{model_size}"
67
 
68
- # Access HF token from environment variable
69
- hf_token = os.environ.get('HF_TOKEN')
70
- if not hf_token:
71
- logger.warning("HF_TOKEN not found in environment variables")
72
-
73
- # Initialize model and tokenizer with token
74
  self.model = AutoModelForCausalLM.from_pretrained(
75
  model_name,
76
- token=hf_token,
77
- device_map='auto' # Automatically handle device placement
78
- )
79
- self.tokenizer = AutoTokenizer.from_pretrained(
80
- model_name,
81
- token=hf_token
82
  )
 
83
 
84
- self.model.eval() # Set to evaluation mode
85
  logger.info(f"Initialized model: {model_name}")
86
 
87
  except Exception as e:
@@ -100,7 +91,7 @@ class MarketingAnalyzer:
100
  )
101
  params = np.load(path)
102
  self.saes[feature.feature_id] = {
103
- 'params': {k: torch.from_numpy(v).cuda() for k, v in params.items()},
104
  'feature': feature
105
  }
106
  logger.info(f"Loaded SAE for feature {feature.feature_id}")
@@ -135,14 +126,23 @@ class MarketingAnalyzer:
135
  feature.threshold
136
  )
137
 
 
 
 
 
 
 
 
 
 
138
  # Record results
139
  feature_result = {
140
  'name': feature.name,
141
  'category': feature.category,
142
- 'activation_score': float(activations.mean()),
143
- 'max_activation': float(activations.max()),
144
  'interpretation': self._interpret_activation(
145
- activations,
146
  feature
147
  )
148
  }
@@ -177,14 +177,13 @@ class MarketingAnalyzer:
177
 
178
  def _interpret_activation(
179
  self,
180
- activations: torch.Tensor,
181
  feature: MarketingFeature
182
  ) -> str:
183
  """Interpret activation patterns for a feature"""
184
- mean_activation = float(activations.mean())
185
- if mean_activation > 0.8:
186
  return f"Very strong presence of {feature.name.lower()}"
187
- elif mean_activation > 0.5:
188
  return f"Moderate presence of {feature.name.lower()}"
189
  else:
190
  return f"Limited presence of {feature.name.lower()}"
@@ -193,21 +192,28 @@ class MarketingAnalyzer:
193
  """Generate content recommendations based on analysis"""
194
  recommendations = []
195
 
196
- # Analyze technical complexity
197
- tech_score = np.mean([
198
- f['activation_score'] for f in results['features'].values()
199
- if f['category'] == 'technical'
200
- ])
201
- if tech_score > 0.8:
202
- recommendations.append(
203
- "Consider simplifying technical language for broader audience"
204
- )
205
- elif tech_score < 0.3:
206
- recommendations.append(
207
- "Could benefit from more specific technical details"
208
- )
 
 
 
 
 
 
 
 
209
 
210
- # Add more recommendation logic as needed
211
  return recommendations
212
 
213
  def create_gradio_interface():
@@ -216,7 +222,6 @@ def create_gradio_interface():
216
  analyzer = MarketingAnalyzer()
217
  except Exception as e:
218
  logger.error(f"Failed to initialize analyzer: {str(e)}")
219
- # Provide a more graceful fallback or error message in the interface
220
  return gr.Interface(
221
  fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
222
  inputs=gr.Textbox(),
@@ -234,8 +239,9 @@ def create_gradio_interface():
234
  # Overall category scores
235
  output += "Category Scores:\n"
236
  for category, features in results['categories'].items():
237
- avg_score = np.mean([f['activation_score'] for f in features])
238
- output += f"{category.title()}: {avg_score:.2f}\n"
 
239
 
240
  # Feature details
241
  output += "\nFeature Details:\n"
@@ -245,13 +251,15 @@ def create_gradio_interface():
245
  output += f"Interpretation: {feature['interpretation']}\n"
246
 
247
  # Recommendations
248
- output += "\nRecommendations:\n"
249
- for rec in results['recommendations']:
250
- output += f"- {rec}\n"
 
251
 
252
  return output
253
 
254
- iface = gr.Interface(
 
255
  fn=analyze,
256
  inputs=gr.Textbox(
257
  lines=5,
@@ -264,10 +272,16 @@ def create_gradio_interface():
264
  ["WordLift is an AI-powered SEO tool"],
265
  ["Our advanced machine learning algorithms optimize your content"],
266
  ["Simple and effective website optimization"]
267
- ]
 
 
 
 
 
 
268
  )
269
 
270
- return iface
271
 
272
  if __name__ == "__main__":
273
  iface = create_gradio_interface()
 
48
  interpretation_guide="High activation suggests strong SEO potential",
49
  layer=20
50
  ),
 
51
  ]
52
 
53
  class MarketingAnalyzer:
54
  """Main class for analyzing marketing content using Gemma Scope"""
55
 
56
+ def __init__(self):
57
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ # Store model size as instance variable
59
+ self.model_size = "2b"
60
+ self._initialize_model()
61
  self._load_saes()
62
 
63
+ def _initialize_model(self):
64
  """Initialize Gemma model and tokenizer"""
65
  try:
66
+ model_name = f"google/gemma-{self.model_size}"
 
67
 
68
+ # Initialize model and tokenizer with token from environment
 
 
 
 
 
69
  self.model = AutoModelForCausalLM.from_pretrained(
70
  model_name,
71
+ device_map='auto'
 
 
 
 
 
72
  )
73
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
74
 
75
+ self.model.eval()
76
  logger.info(f"Initialized model: {model_name}")
77
 
78
  except Exception as e:
 
91
  )
92
  params = np.load(path)
93
  self.saes[feature.feature_id] = {
94
+ 'params': {k: torch.from_numpy(v).to(self.device) for k, v in params.items()},
95
  'feature': feature
96
  }
97
  logger.info(f"Loaded SAE for feature {feature.feature_id}")
 
126
  feature.threshold
127
  )
128
 
129
+ # Skip BOS token and handle empty activations
130
+ activations = activations[:, 1:] # Skip BOS token
131
+ if activations.numel() > 0:
132
+ mean_activation = float(activations.mean())
133
+ max_activation = float(activations.max())
134
+ else:
135
+ mean_activation = 0.0
136
+ max_activation = 0.0
137
+
138
  # Record results
139
  feature_result = {
140
  'name': feature.name,
141
  'category': feature.category,
142
+ 'activation_score': mean_activation,
143
+ 'max_activation': max_activation,
144
  'interpretation': self._interpret_activation(
145
+ mean_activation,
146
  feature
147
  )
148
  }
 
177
 
178
  def _interpret_activation(
179
  self,
180
+ activation: float,
181
  feature: MarketingFeature
182
  ) -> str:
183
  """Interpret activation patterns for a feature"""
184
+ if activation > 0.8:
 
185
  return f"Very strong presence of {feature.name.lower()}"
186
+ elif activation > 0.5:
187
  return f"Moderate presence of {feature.name.lower()}"
188
  else:
189
  return f"Limited presence of {feature.name.lower()}"
 
192
  """Generate content recommendations based on analysis"""
193
  recommendations = []
194
 
195
+ try:
196
+ # Get technical features
197
+ tech_features = [
198
+ f for f in results['features'].values()
199
+ if f['category'] == 'technical'
200
+ ]
201
+
202
+ # Calculate average technical score if we have features
203
+ if tech_features:
204
+ tech_score = np.mean([f['activation_score'] for f in tech_features])
205
+
206
+ if tech_score > 0.8:
207
+ recommendations.append(
208
+ "Consider simplifying technical language for broader audience"
209
+ )
210
+ elif tech_score < 0.3:
211
+ recommendations.append(
212
+ "Could benefit from more specific technical details"
213
+ )
214
+ except Exception as e:
215
+ logger.error(f"Error generating recommendations: {str(e)}")
216
 
 
217
  return recommendations
218
 
219
  def create_gradio_interface():
 
222
  analyzer = MarketingAnalyzer()
223
  except Exception as e:
224
  logger.error(f"Failed to initialize analyzer: {str(e)}")
 
225
  return gr.Interface(
226
  fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
227
  inputs=gr.Textbox(),
 
239
  # Overall category scores
240
  output += "Category Scores:\n"
241
  for category, features in results['categories'].items():
242
+ if features: # Check if we have features for this category
243
+ avg_score = np.mean([f['activation_score'] for f in features])
244
+ output += f"{category.title()}: {avg_score:.2f}\n"
245
 
246
  # Feature details
247
  output += "\nFeature Details:\n"
 
251
  output += f"Interpretation: {feature['interpretation']}\n"
252
 
253
  # Recommendations
254
+ if results['recommendations']:
255
+ output += "\nRecommendations:\n"
256
+ for rec in results['recommendations']:
257
+ output += f"- {rec}\n"
258
 
259
  return output
260
 
261
+ # Create interface with custom styling
262
+ interface = gr.Interface(
263
  fn=analyze,
264
  inputs=gr.Textbox(
265
  lines=5,
 
272
  ["WordLift is an AI-powered SEO tool"],
273
  ["Our advanced machine learning algorithms optimize your content"],
274
  ["Simple and effective website optimization"]
275
+ ],
276
+ theme=gr.themes.Default().set(
277
+ button_primary_background_color="#3452db",
278
+ button_primary_text_color="white",
279
+ button_secondary_background_color="#f5f5f5",
280
+ button_secondary_text_color="#3452db",
281
+ )
282
  )
283
 
284
+ return interface
285
 
286
  if __name__ == "__main__":
287
  iface = create_gradio_interface()