Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -48,40 +48,31 @@ MARKETING_FEATURES = [
|
|
48 |
interpretation_guide="High activation suggests strong SEO potential",
|
49 |
layer=20
|
50 |
),
|
51 |
-
# Add more relevant features as we discover them
|
52 |
]
|
53 |
|
54 |
class MarketingAnalyzer:
|
55 |
"""Main class for analyzing marketing content using Gemma Scope"""
|
56 |
|
57 |
-
def __init__(self
|
58 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
-
|
|
|
|
|
60 |
self._load_saes()
|
61 |
|
62 |
-
def _initialize_model(self
|
63 |
"""Initialize Gemma model and tokenizer"""
|
64 |
try:
|
65 |
-
|
66 |
-
model_name = f"google/gemma-{model_size}"
|
67 |
|
68 |
-
#
|
69 |
-
hf_token = os.environ.get('HF_TOKEN')
|
70 |
-
if not hf_token:
|
71 |
-
logger.warning("HF_TOKEN not found in environment variables")
|
72 |
-
|
73 |
-
# Initialize model and tokenizer with token
|
74 |
self.model = AutoModelForCausalLM.from_pretrained(
|
75 |
model_name,
|
76 |
-
|
77 |
-
device_map='auto' # Automatically handle device placement
|
78 |
-
)
|
79 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
80 |
-
model_name,
|
81 |
-
token=hf_token
|
82 |
)
|
|
|
83 |
|
84 |
-
self.model.eval()
|
85 |
logger.info(f"Initialized model: {model_name}")
|
86 |
|
87 |
except Exception as e:
|
@@ -100,7 +91,7 @@ class MarketingAnalyzer:
|
|
100 |
)
|
101 |
params = np.load(path)
|
102 |
self.saes[feature.feature_id] = {
|
103 |
-
'params': {k: torch.from_numpy(v).
|
104 |
'feature': feature
|
105 |
}
|
106 |
logger.info(f"Loaded SAE for feature {feature.feature_id}")
|
@@ -135,14 +126,23 @@ class MarketingAnalyzer:
|
|
135 |
feature.threshold
|
136 |
)
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
# Record results
|
139 |
feature_result = {
|
140 |
'name': feature.name,
|
141 |
'category': feature.category,
|
142 |
-
'activation_score':
|
143 |
-
'max_activation':
|
144 |
'interpretation': self._interpret_activation(
|
145 |
-
|
146 |
feature
|
147 |
)
|
148 |
}
|
@@ -177,14 +177,13 @@ class MarketingAnalyzer:
|
|
177 |
|
178 |
def _interpret_activation(
|
179 |
self,
|
180 |
-
|
181 |
feature: MarketingFeature
|
182 |
) -> str:
|
183 |
"""Interpret activation patterns for a feature"""
|
184 |
-
|
185 |
-
if mean_activation > 0.8:
|
186 |
return f"Very strong presence of {feature.name.lower()}"
|
187 |
-
elif
|
188 |
return f"Moderate presence of {feature.name.lower()}"
|
189 |
else:
|
190 |
return f"Limited presence of {feature.name.lower()}"
|
@@ -193,21 +192,28 @@ class MarketingAnalyzer:
|
|
193 |
"""Generate content recommendations based on analysis"""
|
194 |
recommendations = []
|
195 |
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
-
# Add more recommendation logic as needed
|
211 |
return recommendations
|
212 |
|
213 |
def create_gradio_interface():
|
@@ -216,7 +222,6 @@ def create_gradio_interface():
|
|
216 |
analyzer = MarketingAnalyzer()
|
217 |
except Exception as e:
|
218 |
logger.error(f"Failed to initialize analyzer: {str(e)}")
|
219 |
-
# Provide a more graceful fallback or error message in the interface
|
220 |
return gr.Interface(
|
221 |
fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
|
222 |
inputs=gr.Textbox(),
|
@@ -234,8 +239,9 @@ def create_gradio_interface():
|
|
234 |
# Overall category scores
|
235 |
output += "Category Scores:\n"
|
236 |
for category, features in results['categories'].items():
|
237 |
-
|
238 |
-
|
|
|
239 |
|
240 |
# Feature details
|
241 |
output += "\nFeature Details:\n"
|
@@ -245,13 +251,15 @@ def create_gradio_interface():
|
|
245 |
output += f"Interpretation: {feature['interpretation']}\n"
|
246 |
|
247 |
# Recommendations
|
248 |
-
|
249 |
-
|
250 |
-
|
|
|
251 |
|
252 |
return output
|
253 |
|
254 |
-
|
|
|
255 |
fn=analyze,
|
256 |
inputs=gr.Textbox(
|
257 |
lines=5,
|
@@ -264,10 +272,16 @@ def create_gradio_interface():
|
|
264 |
["WordLift is an AI-powered SEO tool"],
|
265 |
["Our advanced machine learning algorithms optimize your content"],
|
266 |
["Simple and effective website optimization"]
|
267 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
)
|
269 |
|
270 |
-
return
|
271 |
|
272 |
if __name__ == "__main__":
|
273 |
iface = create_gradio_interface()
|
|
|
48 |
interpretation_guide="High activation suggests strong SEO potential",
|
49 |
layer=20
|
50 |
),
|
|
|
51 |
]
|
52 |
|
53 |
class MarketingAnalyzer:
|
54 |
"""Main class for analyzing marketing content using Gemma Scope"""
|
55 |
|
56 |
+
def __init__(self):
|
57 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
+
# Store model size as instance variable
|
59 |
+
self.model_size = "2b"
|
60 |
+
self._initialize_model()
|
61 |
self._load_saes()
|
62 |
|
63 |
+
def _initialize_model(self):
|
64 |
"""Initialize Gemma model and tokenizer"""
|
65 |
try:
|
66 |
+
model_name = f"google/gemma-{self.model_size}"
|
|
|
67 |
|
68 |
+
# Initialize model and tokenizer with token from environment
|
|
|
|
|
|
|
|
|
|
|
69 |
self.model = AutoModelForCausalLM.from_pretrained(
|
70 |
model_name,
|
71 |
+
device_map='auto'
|
|
|
|
|
|
|
|
|
|
|
72 |
)
|
73 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
74 |
|
75 |
+
self.model.eval()
|
76 |
logger.info(f"Initialized model: {model_name}")
|
77 |
|
78 |
except Exception as e:
|
|
|
91 |
)
|
92 |
params = np.load(path)
|
93 |
self.saes[feature.feature_id] = {
|
94 |
+
'params': {k: torch.from_numpy(v).to(self.device) for k, v in params.items()},
|
95 |
'feature': feature
|
96 |
}
|
97 |
logger.info(f"Loaded SAE for feature {feature.feature_id}")
|
|
|
126 |
feature.threshold
|
127 |
)
|
128 |
|
129 |
+
# Skip BOS token and handle empty activations
|
130 |
+
activations = activations[:, 1:] # Skip BOS token
|
131 |
+
if activations.numel() > 0:
|
132 |
+
mean_activation = float(activations.mean())
|
133 |
+
max_activation = float(activations.max())
|
134 |
+
else:
|
135 |
+
mean_activation = 0.0
|
136 |
+
max_activation = 0.0
|
137 |
+
|
138 |
# Record results
|
139 |
feature_result = {
|
140 |
'name': feature.name,
|
141 |
'category': feature.category,
|
142 |
+
'activation_score': mean_activation,
|
143 |
+
'max_activation': max_activation,
|
144 |
'interpretation': self._interpret_activation(
|
145 |
+
mean_activation,
|
146 |
feature
|
147 |
)
|
148 |
}
|
|
|
177 |
|
178 |
def _interpret_activation(
|
179 |
self,
|
180 |
+
activation: float,
|
181 |
feature: MarketingFeature
|
182 |
) -> str:
|
183 |
"""Interpret activation patterns for a feature"""
|
184 |
+
if activation > 0.8:
|
|
|
185 |
return f"Very strong presence of {feature.name.lower()}"
|
186 |
+
elif activation > 0.5:
|
187 |
return f"Moderate presence of {feature.name.lower()}"
|
188 |
else:
|
189 |
return f"Limited presence of {feature.name.lower()}"
|
|
|
192 |
"""Generate content recommendations based on analysis"""
|
193 |
recommendations = []
|
194 |
|
195 |
+
try:
|
196 |
+
# Get technical features
|
197 |
+
tech_features = [
|
198 |
+
f for f in results['features'].values()
|
199 |
+
if f['category'] == 'technical'
|
200 |
+
]
|
201 |
+
|
202 |
+
# Calculate average technical score if we have features
|
203 |
+
if tech_features:
|
204 |
+
tech_score = np.mean([f['activation_score'] for f in tech_features])
|
205 |
+
|
206 |
+
if tech_score > 0.8:
|
207 |
+
recommendations.append(
|
208 |
+
"Consider simplifying technical language for broader audience"
|
209 |
+
)
|
210 |
+
elif tech_score < 0.3:
|
211 |
+
recommendations.append(
|
212 |
+
"Could benefit from more specific technical details"
|
213 |
+
)
|
214 |
+
except Exception as e:
|
215 |
+
logger.error(f"Error generating recommendations: {str(e)}")
|
216 |
|
|
|
217 |
return recommendations
|
218 |
|
219 |
def create_gradio_interface():
|
|
|
222 |
analyzer = MarketingAnalyzer()
|
223 |
except Exception as e:
|
224 |
logger.error(f"Failed to initialize analyzer: {str(e)}")
|
|
|
225 |
return gr.Interface(
|
226 |
fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
|
227 |
inputs=gr.Textbox(),
|
|
|
239 |
# Overall category scores
|
240 |
output += "Category Scores:\n"
|
241 |
for category, features in results['categories'].items():
|
242 |
+
if features: # Check if we have features for this category
|
243 |
+
avg_score = np.mean([f['activation_score'] for f in features])
|
244 |
+
output += f"{category.title()}: {avg_score:.2f}\n"
|
245 |
|
246 |
# Feature details
|
247 |
output += "\nFeature Details:\n"
|
|
|
251 |
output += f"Interpretation: {feature['interpretation']}\n"
|
252 |
|
253 |
# Recommendations
|
254 |
+
if results['recommendations']:
|
255 |
+
output += "\nRecommendations:\n"
|
256 |
+
for rec in results['recommendations']:
|
257 |
+
output += f"- {rec}\n"
|
258 |
|
259 |
return output
|
260 |
|
261 |
+
# Create interface with custom styling
|
262 |
+
interface = gr.Interface(
|
263 |
fn=analyze,
|
264 |
inputs=gr.Textbox(
|
265 |
lines=5,
|
|
|
272 |
["WordLift is an AI-powered SEO tool"],
|
273 |
["Our advanced machine learning algorithms optimize your content"],
|
274 |
["Simple and effective website optimization"]
|
275 |
+
],
|
276 |
+
theme=gr.themes.Default().set(
|
277 |
+
button_primary_background_color="#3452db",
|
278 |
+
button_primary_text_color="white",
|
279 |
+
button_secondary_background_color="#f5f5f5",
|
280 |
+
button_secondary_text_color="#3452db",
|
281 |
+
)
|
282 |
)
|
283 |
|
284 |
+
return interface
|
285 |
|
286 |
if __name__ == "__main__":
|
287 |
iface = create_gradio_interface()
|