Spaces:
Running
Running
update
Browse files
app.py
CHANGED
@@ -83,7 +83,6 @@ class MarketingAnalyzer:
|
|
83 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
84 |
torch.set_grad_enabled(False) # Avoid memory issues
|
85 |
self._initialize_model()
|
86 |
-
self._load_saes()
|
87 |
|
88 |
def _initialize_model(self):
|
89 |
try:
|
@@ -97,35 +96,31 @@ class MarketingAnalyzer:
|
|
97 |
logger.error(f"Error initializing model: {str(e)}")
|
98 |
raise
|
99 |
|
100 |
-
def
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
params = np.load(path)
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
|
127 |
-
)
|
128 |
-
continue
|
129 |
|
130 |
def _gather_activations(self, text: str, layer: int):
|
131 |
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
|
@@ -143,7 +138,23 @@ class MarketingAnalyzer:
|
|
143 |
|
144 |
return target_act, inputs
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
def analyze_content(self, text: str) -> Dict:
|
|
|
147 |
results = {
|
148 |
"text": text,
|
149 |
"features": {},
|
@@ -152,44 +163,74 @@ class MarketingAnalyzer:
|
|
152 |
}
|
153 |
|
154 |
try:
|
155 |
-
#
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
feature_result = {
|
177 |
-
"name":
|
178 |
-
"category":
|
179 |
-
"activation_score": mean_activation,
|
180 |
-
"max_activation": max_activation,
|
181 |
"interpretation": self._interpret_activation(
|
182 |
-
mean_activation,
|
183 |
),
|
184 |
}
|
185 |
|
186 |
results["features"][feature_id] = feature_result
|
187 |
|
188 |
-
if
|
189 |
-
results["categories"][
|
190 |
-
results["categories"][
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
except Exception as e:
|
195 |
logger.error(f"Error analyzing content: {str(e)}")
|
@@ -197,38 +238,14 @@ class MarketingAnalyzer:
|
|
197 |
|
198 |
return results
|
199 |
|
200 |
-
def _interpret_activation(
|
201 |
-
|
202 |
-
) -> str:
|
203 |
if activation > 0.8:
|
204 |
-
return f"Very strong
|
205 |
elif activation > 0.5:
|
206 |
-
return f"Moderate
|
207 |
else:
|
208 |
-
return f"Limited
|
209 |
-
|
210 |
-
def _generate_recommendations(self, results: Dict) -> List[str]:
|
211 |
-
recommendations = []
|
212 |
-
|
213 |
-
try:
|
214 |
-
tech_features = [
|
215 |
-
f for f in results["features"].values() if f["category"] == "technical"
|
216 |
-
]
|
217 |
-
|
218 |
-
if tech_features:
|
219 |
-
tech_score = np.mean([f["activation_score"] for f in tech_features])
|
220 |
-
if tech_score > 0.8:
|
221 |
-
recommendations.append(
|
222 |
-
"Consider simplifying technical language for broader audience"
|
223 |
-
)
|
224 |
-
elif tech_score < 0.3:
|
225 |
-
recommendations.append(
|
226 |
-
"Could benefit from more specific technical details"
|
227 |
-
)
|
228 |
-
except Exception as e:
|
229 |
-
logger.error(f"Error generating recommendations: {str(e)}")
|
230 |
-
|
231 |
-
return recommendations
|
232 |
|
233 |
|
234 |
def create_gradio_interface():
|
|
|
83 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
84 |
torch.set_grad_enabled(False) # Avoid memory issues
|
85 |
self._initialize_model()
|
|
|
86 |
|
87 |
def _initialize_model(self):
|
88 |
try:
|
|
|
96 |
logger.error(f"Error initializing model: {str(e)}")
|
97 |
raise
|
98 |
|
99 |
+
def _load_sae(self, feature_id: int, layer: int = 20):
|
100 |
+
"""Dynamically load a single SAE"""
|
101 |
+
try:
|
102 |
+
path = hf_hub_download(
|
103 |
+
repo_id="google/gemma-scope-2b-pt-res",
|
104 |
+
filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
|
105 |
+
force_download=False,
|
106 |
+
)
|
107 |
+
params = np.load(path)
|
|
|
108 |
|
109 |
+
# Create SAE
|
110 |
+
d_model = params["W_enc"].shape[0]
|
111 |
+
d_sae = params["W_enc"].shape[1]
|
112 |
+
sae = JumpReLUSAE(d_model, d_sae).to(self.device)
|
113 |
|
114 |
+
# Load parameters
|
115 |
+
sae_params = {
|
116 |
+
k: torch.from_numpy(v).to(self.device) for k, v in params.items()
|
117 |
+
}
|
118 |
+
sae.load_state_dict(sae_params)
|
119 |
|
120 |
+
return sae
|
121 |
+
except Exception as e:
|
122 |
+
logger.error(f"Error loading SAE for feature {feature_id}: {str(e)}")
|
123 |
+
return None
|
|
|
|
|
|
|
124 |
|
125 |
def _gather_activations(self, text: str, layer: int):
|
126 |
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
|
|
|
138 |
|
139 |
return target_act, inputs
|
140 |
|
141 |
+
def _get_feature_activations(self, text: str, sae, layer: int = 20):
|
142 |
+
"""Get activations for a single feature"""
|
143 |
+
activations, _ = self._gather_activations(text, layer)
|
144 |
+
sae_acts = sae.encode(activations.to(torch.float32))
|
145 |
+
sae_acts = sae_acts[:, 1:] # Skip BOS token
|
146 |
+
|
147 |
+
if sae_acts.numel() > 0:
|
148 |
+
mean_activation = float(sae_acts.mean())
|
149 |
+
max_activation = float(sae_acts.max())
|
150 |
+
else:
|
151 |
+
mean_activation = 0.0
|
152 |
+
max_activation = 0.0
|
153 |
+
|
154 |
+
return mean_activation, max_activation
|
155 |
+
|
156 |
def analyze_content(self, text: str) -> Dict:
|
157 |
+
"""Analyze content and find most relevant features"""
|
158 |
results = {
|
159 |
"text": text,
|
160 |
"features": {},
|
|
|
163 |
}
|
164 |
|
165 |
try:
|
166 |
+
# Start with a set of potential features to explore
|
167 |
+
feature_pool = list(range(1, 16385)) # Full range of features
|
168 |
+
sample_size = 50 # Number of features to sample
|
169 |
+
sampled_features = np.random.choice(
|
170 |
+
feature_pool, sample_size, replace=False
|
171 |
+
)
|
172 |
+
|
173 |
+
# Test each feature
|
174 |
+
feature_activations = []
|
175 |
+
for feature_id in sampled_features:
|
176 |
+
sae = self._load_sae(feature_id)
|
177 |
+
if sae is None:
|
178 |
+
continue
|
179 |
+
|
180 |
+
mean_activation, max_activation = self._get_feature_activations(
|
181 |
+
text, sae
|
182 |
+
)
|
183 |
+
feature_activations.append(
|
184 |
+
{
|
185 |
+
"feature_id": feature_id,
|
186 |
+
"mean_activation": mean_activation,
|
187 |
+
"max_activation": max_activation,
|
188 |
+
}
|
189 |
+
)
|
190 |
+
|
191 |
+
# Sort by activation and take top features
|
192 |
+
top_features = sorted(
|
193 |
+
feature_activations, key=lambda x: x["max_activation"], reverse=True
|
194 |
+
)[
|
195 |
+
:3
|
196 |
+
] # Keep top 3 features
|
197 |
+
|
198 |
+
# Analyze top features in detail
|
199 |
+
for feature_data in top_features:
|
200 |
+
feature_id = feature_data["feature_id"]
|
201 |
+
|
202 |
+
# Get neuronpedia data if available (this would be a placeholder)
|
203 |
+
feature_name = f"Feature {feature_id}"
|
204 |
+
feature_category = "neural" # Default category
|
205 |
+
|
206 |
feature_result = {
|
207 |
+
"name": feature_name,
|
208 |
+
"category": feature_category,
|
209 |
+
"activation_score": feature_data["mean_activation"],
|
210 |
+
"max_activation": feature_data["max_activation"],
|
211 |
"interpretation": self._interpret_activation(
|
212 |
+
feature_data["mean_activation"], feature_id
|
213 |
),
|
214 |
}
|
215 |
|
216 |
results["features"][feature_id] = feature_result
|
217 |
|
218 |
+
if feature_category not in results["categories"]:
|
219 |
+
results["categories"][feature_category] = []
|
220 |
+
results["categories"][feature_category].append(feature_result)
|
221 |
+
|
222 |
+
# Generate recommendations based on activations
|
223 |
+
if top_features:
|
224 |
+
max_activation = max(f["max_activation"] for f in top_features)
|
225 |
+
if max_activation > 0.8:
|
226 |
+
results["recommendations"].append(
|
227 |
+
f"Strong activation detected in feature {top_features[0]['feature_id']}. "
|
228 |
+
"Consider exploring this aspect further."
|
229 |
+
)
|
230 |
+
elif max_activation < 0.3:
|
231 |
+
results["recommendations"].append(
|
232 |
+
"Low feature activations overall. Content might benefit from more distinctive elements."
|
233 |
+
)
|
234 |
|
235 |
except Exception as e:
|
236 |
logger.error(f"Error analyzing content: {str(e)}")
|
|
|
238 |
|
239 |
return results
|
240 |
|
241 |
+
def _interpret_activation(self, activation: float, feature_id: int) -> str:
|
242 |
+
"""Interpret activation levels for a feature"""
|
|
|
243 |
if activation > 0.8:
|
244 |
+
return f"Very strong activation of feature {feature_id}"
|
245 |
elif activation > 0.5:
|
246 |
+
return f"Moderate activation of feature {feature_id}"
|
247 |
else:
|
248 |
+
return f"Limited activation of feature {feature_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
|
251 |
def create_gradio_interface():
|