File size: 12,031 Bytes
f85532f
 
7e6371a
f85532f
 
 
 
 
 
 
 
 
 
94ca202
f85532f
 
 
94ca202
f85532f
 
 
 
 
 
 
 
94ca202
7e6371a
f85532f
 
 
 
 
 
 
7e6371a
f85532f
 
 
 
 
 
 
7e6371a
f85532f
 
 
 
 
 
 
7e6371a
f85532f
 
 
94ca202
7e6371a
 
 
 
 
 
 
 
 
 
 
 
 
 
574ab91
7e6371a
 
 
 
 
 
 
 
 
 
9186441
f85532f
7e6371a
9186441
f85532f
 
9186441
f85532f
deaf693
7e6371a
deaf693
7e6371a
9186441
7e6371a
f85532f
 
 
 
 
 
 
 
 
7e6371a
94ca202
7e6371a
f85532f
 
7e6371a
 
 
 
 
 
 
 
 
f85532f
7e6371a
 
 
f85532f
 
94ca202
 
 
f85532f
 
7e6371a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f85532f
 
94ca202
 
 
 
f85532f
574ab91
f85532f
7e6371a
f85532f
94ca202
7e6371a
574ab91
7e6371a
 
 
 
 
 
574ab91
7e6371a
 
 
 
9186441
 
 
574ab91
f85532f
 
94ca202
 
 
 
 
 
 
f85532f
574ab91
94ca202
574ab91
94ca202
 
 
574ab91
94ca202
574ab91
f85532f
 
 
574ab91
f85532f
 
 
94ca202
f85532f
9186441
f85532f
9186441
f85532f
 
 
 
 
 
574ab91
9186441
 
94ca202
9186441
574ab91
9186441
94ca202
9186441
 
 
 
 
 
 
 
 
 
574ab91
f85532f
 
94ca202
f85532f
deaf693
 
 
 
 
 
 
 
 
7e6371a
deaf693
574ab91
f85532f
 
574ab91
ad0839d
574ab91
ad0839d
94ca202
7e6371a
94ca202
ad0839d
574ab91
ad0839d
94ca202
ad0839d
 
 
 
 
574ab91
94ca202
ad0839d
94ca202
9186441
574ab91
ad0839d
 
 
 
 
 
 
 
574ab91
7e6371a
 
 
 
 
 
 
 
 
 
 
 
 
 
ad0839d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6371a
ad0839d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6371a
ad0839d
 
7e6371a
ad0839d
7e6371a
574ab91
9186441
f85532f
94ca202
f85532f
 
94ca202
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging

# Configure root logging at INFO level and create this module's logger,
# which the analyzer and UI code below use for status and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass()
class MarketingFeature:
    """Metadata for one SAE feature that the analyzer reports on.

    Besides identifying the feature (SAE latent index plus model layer),
    each record carries the human-readable text shown in the UI.
    """

    # SAE latent index; also used to build the Neuronpedia link for the feature.
    feature_id: int
    # Display name shown in the report.
    name: str
    # Grouping key for category scores, e.g. "technical" or "seo".
    category: str
    # Short explanation of what the feature detects.
    description: str
    # Advice on how to read the feature's activation.
    interpretation_guide: str
    # Model layer index; selects the SAE params file and the hooked decoder layer.
    layer: int
    # NOTE(review): not read anywhere in the visible code.
    threshold: float = 0.1


# Marketing-relevant features, each given as
# (feature_id, name, category, description, interpretation_guide).
# Every entry is read from layer 20, i.e. from the same SAE params file.
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=fid,
        name=name,
        category=category,
        description=description,
        interpretation_guide=guide,
        layer=20,
    )
    for fid, name, category, description, guide in (
        (
            35,
            "Technical Term Detector",
            "technical",
            "Detects technical and specialized terminology",
            "High activation indicates strong technical focus",
        ),
        (
            6680,
            "Compound Technical Terms",
            "technical",
            "Identifies complex technical concepts",
            "Consider simplifying language if activation is too high",
        ),
        (
            2,
            "SEO Keyword Detector",
            "seo",
            "Identifies potential SEO keywords",
            "High activation suggests strong SEO potential",
        ),
    )
]


class JumpReLUSAE(nn.Module):
    """Sparse autoencoder with a JumpReLU encoder nonlinearity.

    All parameters start as zeros and are meant to be overwritten through
    ``load_state_dict``. The parameter names (W_enc, W_dec, threshold,
    b_enc, b_dec) therefore form the checkpoint contract and must not change.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registered in the same order as the checkpoint keys; nn.Module's
        # __setattr__ registers each nn.Parameter.
        for param_name, shape in (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        ):
            setattr(self, param_name, nn.Parameter(torch.zeros(*shape)))

    def encode(self, input_acts):
        """Map model activations to sparse latents.

        A latent survives only where its pre-activation strictly exceeds its
        per-latent threshold; survivors are additionally ReLU-clipped.
        """
        pre_acts = input_acts @ self.W_enc + self.b_enc
        gate = pre_acts > self.threshold
        return torch.nn.functional.relu(pre_acts) * gate

    def decode(self, acts):
        """Reconstruct model activations from latents."""
        return self.b_dec + acts @ self.W_dec

    def forward(self, acts):
        """Encode then decode, returning the reconstruction."""
        return self.decode(self.encode(acts))


class MarketingAnalyzer:
    """Score marketing copy against the SAE features in ``MARKETING_FEATURES``.

    Construction loads ``google/gemma-2-2b`` and one JumpReLU SAE per distinct
    layer referenced by the features. ``analyze_content`` runs the model on a
    text, encodes the hooked layer output with the SAE, and reports, per
    feature, the mean and max activation of that feature's latent over the
    non-BOS tokens.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.set_grad_enabled(False)  # Avoid memory issues
        self._initialize_model()
        self._load_saes()

    def _initialize_model(self):
        """Load the model and tokenizer; log and re-raise on failure."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/gemma-2-2b", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
            self.model.eval()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_saes(self):
        """Populate ``self.saes`` as ``{feature_id: {"sae": ..., "feature": ...}}``.

        Features on the same layer share a single SAE instance, so each params
        file is downloaded and materialized only once. A feature whose SAE
        fails to load is logged and skipped; the others are still loaded.
        """
        self.saes = {}
        sae_by_layer = {}
        for feature in MARKETING_FEATURES:
            try:
                sae = sae_by_layer.get(feature.layer)
                if sae is None:
                    sae = self._load_sae(feature.layer)
                    sae_by_layer[feature.layer] = sae
                self.saes[feature.feature_id] = {"sae": sae, "feature": feature}
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.error(
                    f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
                )

    def _load_sae(self, layer: int) -> "JumpReLUSAE":
        """Download the 16k-wide Gemma Scope residual SAE for ``layer`` and build it."""
        path = hf_hub_download(
            repo_id="google/gemma-scope-2b-pt-res",
            filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
            force_download=False,
        )
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager closes it once the arrays have been copied out.
        with np.load(path) as params:
            d_model, d_sae = params["W_enc"].shape
            sae = JumpReLUSAE(d_model, d_sae).to(self.device)
            sae.load_state_dict(
                {k: torch.from_numpy(v).to(self.device) for k, v in params.items()}
            )
        return sae

    def _gather_activations(self, text: str, layer: int):
        """Run the model on ``text`` and capture the output of decoder ``layer``.

        Returns:
            ``(activations, inputs)``: the hidden state emitted by the layer
            and the tokenized inputs that were fed to the model.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        target_act = None

        def hook(mod, hook_inputs, outputs):
            nonlocal target_act
            # Decoder layers have returned a tuple whose first element is the
            # hidden state; some versions may return the tensor directly.
            # Indexing a bare tensor with [0] would silently drop the batch
            # axis, so only unpack real tuples.
            target_act = outputs[0] if isinstance(outputs, tuple) else outputs

        handle = self.model.model.layers[layer].register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(**inputs)
        finally:
            # Always detach the hook, even if the forward pass raises, so it
            # does not stay attached to the shared model.
            handle.remove()

        return target_act, inputs

    def analyze_content(self, text: str) -> Dict:
        """Score ``text`` against every loaded SAE feature.

        Args:
            text: The marketing copy to analyze.

        Returns:
            A dict with keys ``"text"``; ``"features"`` (feature_id -> result
            dict with name, category, activation_score, max_activation,
            interpretation); ``"categories"`` (category -> list of those same
            result dicts); and ``"recommendations"`` (list of strings).

        Raises:
            Any exception from the model or SAE pass, after logging it.
        """
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }

        try:
            # Layer activations depend only on (text, layer). Features that
            # share a layer therefore reuse one forward pass.
            layer_activations = {}
            for feature_id, sae_data in self.saes.items():
                feature = sae_data["feature"]
                sae = sae_data["sae"]

                if feature.layer not in layer_activations:
                    activations, _ = self._gather_activations(text, feature.layer)
                    layer_activations[feature.layer] = activations
                activations = layer_activations[feature.layer]

                # Encode with the SAE. Keep only this feature's latent (last
                # axis) and skip the BOS position (token axis 1). Averaging
                # over every latent would give all features of one SAE the
                # same score.
                sae_acts = sae.encode(activations.to(torch.float32))
                feature_acts = sae_acts[:, 1:, feature.feature_id]

                # An empty text leaves no tokens after BOS.
                if feature_acts.numel() > 0:
                    mean_activation = float(feature_acts.mean())
                    max_activation = float(feature_acts.max())
                else:
                    mean_activation = 0.0
                    max_activation = 0.0

                feature_result = {
                    "name": feature.name,
                    "category": feature.category,
                    "activation_score": mean_activation,
                    "max_activation": max_activation,
                    "interpretation": self._interpret_activation(
                        mean_activation, feature
                    ),
                }

                results["features"][feature_id] = feature_result
                results["categories"].setdefault(feature.category, []).append(
                    feature_result
                )

            results["recommendations"] = self._generate_recommendations(results)

        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

        return results

    def _interpret_activation(
        self, activation: float, feature: "MarketingFeature"
    ) -> str:
        """Map a mean activation to a coarse strength phrase (>0.8, >0.5, else)."""
        if activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Derive advice from the mean score of the "technical" features.

        Errors are logged and an empty or partial list is returned, so that a
        recommendation failure never breaks the analysis itself.
        """
        recommendations = []

        try:
            tech_features = [
                f for f in results["features"].values() if f["category"] == "technical"
            ]

            if tech_features:
                tech_score = np.mean([f["activation_score"] for f in tech_features])
                if tech_score > 0.8:
                    recommendations.append(
                        "Consider simplifying technical language for broader audience"
                    )
                elif tech_score < 0.3:
                    recommendations.append(
                        "Could benefit from more specific technical details"
                    )
        except Exception as e:
            logger.error(f"Error generating recommendations: {str(e)}")

        return recommendations


def create_gradio_interface():
    """Build the Gradio UI for the marketing content analyzer.

    If the analyzer (model plus SAEs) cannot be constructed, a minimal
    ``gr.Interface`` that only reports the failure is returned instead.

    Returns:
        A Gradio app object exposing ``launch()``.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize.",
        )

    def analyze(text):
        """Analyze ``text`` and render a Markdown report.

        Returns:
            ``(markdown, dashboard_url, feature_id)``. The last two refer to the
            highest-scoring feature, or are ``None`` when no feature could be
            scored (e.g. every SAE failed to load).
        """
        results = analyzer.analyze_content(text)

        output = "# Content Analysis Results\n\n"

        output += "## Category Scores\n"
        for category, features in results["categories"].items():
            if features:
                avg_score = np.mean([f["activation_score"] for f in features])
                output += f"**{category.title()}**: {avg_score:.2f}\n"

        output += "\n## Feature Details\n"
        for feature_id, feature in results["features"].items():
            output += f"\n### {feature['name']} (Feature {feature_id})\n"
            output += f"**Score**: {feature['activation_score']:.2f}\n\n"
            output += f"**Interpretation**: {feature['interpretation']}\n\n"
            # Add feature explanation from Neuronpedia reference
            output += f"[View feature details on Neuronpedia](https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id})\n\n"

        if results["recommendations"]:
            output += "\n## Recommendations\n"
            for rec in results["recommendations"]:
                output += f"- {rec}\n"

        # max() on an empty mapping raises ValueError. This happens when every
        # SAE failed to load, so report that instead of crashing the handler.
        if not results["features"]:
            output += "\n_No features were available for analysis._\n"
            return output, None, None

        feature_id = max(
            results["features"].items(), key=lambda x: x[1]["activation_score"]
        )[0]

        # Build dashboard URL for the highest activating feature
        dashboard_url = f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

        return output, dashboard_url, feature_id

    with gr.Blocks(
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
            primary_hue="indigo",
            secondary_hue="blue",
            neutral_hue="gray",
        )
    ) as interface:
        gr.Markdown("# Marketing Content Analyzer")
        gr.Markdown(
            "Analyze your marketing content using Gemma Scope's neural features"
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter your marketing content here...",
                    label="Marketing Content",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                gr.Examples(
                    examples=[
                        "WordLift is an AI-powered SEO tool",
                        "Our advanced machine learning algorithms optimize your content",
                        "Simple and effective website optimization",
                    ],
                    inputs=input_text,
                )

            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Analysis Results")
                # gr.Box was removed in Gradio 4; gr.Group is the grouping
                # container available in both 3.x and 4.x.
                with gr.Group():
                    gr.Markdown("## Feature Dashboard")
                    feature_id_text = gr.Text(
                        label="Currently viewing feature", show_label=False
                    )
                    dashboard_frame = gr.HTML(label="Feature Dashboard")

        def update_dashboard(text):
            """Click handler: report markdown, dashboard iframe and caption."""
            output, dashboard_url, feature_id = analyze(text)
            if dashboard_url is None:
                return output, "", "No feature available to display"
            return (
                output,
                f"<iframe src='{dashboard_url}' width='100%' height='600px' frameborder='0'></iframe>",
                f"Currently viewing Feature {feature_id} - Most active feature in your content",
            )

        analyze_btn.click(
            fn=update_dashboard,
            inputs=input_text,
            outputs=[output_text, dashboard_frame, feature_id_text],
        )

    return interface


if __name__ == "__main__":
    # Script entry point: build the app (the error-fallback UI if the
    # analyzer failed to initialize) and start the Gradio server.
    iface = create_gradio_interface()
    iface.launch()