File size: 9,930 Bytes
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deaf693
f85532f
 
deaf693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f85532f
deaf693
f85532f
deaf693
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deaf693
 
 
 
 
 
 
 
 
 
 
 
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging

# Configure the root logger at INFO level and create the module-level logger
# that the rest of this file uses for status and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class MarketingFeature:
    """Metadata record describing one marketing-relevant Gemma Scope feature.

    Instances are listed in ``MARKETING_FEATURES`` and consumed by
    ``MarketingAnalyzer``.
    """
    # Identifier of the feature; the analyzer uses it as the key for the
    # loaded SAE and for the per-feature results.
    # NOTE(review): presumably the index of the latent inside the SAE - confirm.
    feature_id: int
    # Human-readable name; shown in reports and (lower-cased) in interpretations.
    name: str
    # Grouping key used to aggregate results (e.g. "technical", "seo").
    category: str
    # Free-text explanation of what the feature detects.
    description: str
    # Free-text hint on how to read a high or low activation.
    interpretation_guide: str
    # Model layer the feature belongs to; selects the SAE file and hidden state.
    layer: int
    # Threshold value handed to the analyzer's SAE step; defaults to 0.1.
    threshold: float = 0.1

# Catalogue of Gemma Scope features considered relevant for marketing copy.
# Every entry currently sits on layer 20, so the layer is supplied once and
# only the per-feature fields are spelled out in each spec.
MARKETING_FEATURES = [
    MarketingFeature(layer=20, **spec)
    for spec in (
        dict(
            feature_id=35,
            name="Technical Term Detector",
            category="technical",
            description="Detects technical and specialized terminology",
            interpretation_guide="High activation indicates strong technical focus",
        ),
        dict(
            feature_id=6680,
            name="Compound Technical Terms",
            category="technical",
            description="Identifies complex technical concepts",
            interpretation_guide="Consider simplifying language if activation is too high",
        ),
        dict(
            feature_id=2,
            name="SEO Keyword Detector",
            category="seo",
            description="Identifies potential SEO keywords",
            interpretation_guide="High activation suggests strong SEO potential",
        ),
        # Add more relevant features as we discover them
    )
]

class MarketingAnalyzer:
    """Analyze marketing copy with Gemma Scope sparse-autoencoder (SAE) features.

    Construction loads the Gemma 2 base model, its tokenizer and the SAE
    weights for every entry in ``MARKETING_FEATURES``. ``analyze_content``
    then runs a text through the model and reports, per feature, how strongly
    the corresponding SAE latent fires.

    Attributes:
        device: torch device used for the tokenized inputs and SAE weights.
        model_size: Size suffix (e.g. ``"2b"``) shared by the model name and
            the Gemma Scope repository id.
        model: The causal language model, in eval mode.
        tokenizer: Tokenizer matching ``model``.
        saes: Maps feature id -> ``{'params': <SAE tensors>, 'feature': <MarketingFeature>}``.
    """

    def __init__(self, model_size: str = "2b"):
        """Load the model and the SAEs.

        Args:
            model_size: Gemma size suffix, e.g. ``"2b"``.

        Raises:
            Exception: Whatever model/tokenizer loading raises is propagated.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # _load_saes builds the Gemma Scope repo id from this attribute, so it
        # must be set before that call.
        self.model_size = model_size
        self._initialize_model(model_size)
        self._load_saes()

    def _initialize_model(self, model_size: str):
        """Initialize the Gemma model and tokenizer.

        Args:
            model_size: Gemma size suffix, e.g. ``"2b"``.

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            import os
            # Gemma Scope SAEs are trained on Gemma 2 activations, so the
            # Gemma 2 checkpoint is required for the SAE weights to match the
            # model's layer count and hidden width.
            model_name = f"google/gemma-2-{model_size}"

            # Access HF token from environment variable
            hf_token = os.environ.get('HF_TOKEN')
            if not hf_token:
                logger.warning("HF_TOKEN not found in environment variables")

            # Initialize model and tokenizer with token
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                device_map='auto'  # Automatically handle device placement
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=hf_token
            )

            self.model.eval()  # Set to evaluation mode
            logger.info(f"Initialized model: {model_name}")

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_saes(self):
        """Load the SAE weights needed by ``MARKETING_FEATURES``.

        Populates ``self.saes``. A feature whose SAE cannot be loaded is
        logged and skipped (best effort), so ``self.saes`` may end up with
        fewer entries than ``MARKETING_FEATURES``.
        """
        self.saes = {}
        # Several features share one layer (and therefore one SAE file);
        # load each file once and share the tensors between those features.
        params_by_layer: Dict[int, Dict[str, torch.Tensor]] = {}
        for feature in MARKETING_FEATURES:
            try:
                if feature.layer not in params_by_layer:
                    path = hf_hub_download(
                        repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
                        filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz"
                    )
                    # Context manager closes the underlying .npz file handle.
                    with np.load(path) as params:
                        params_by_layer[feature.layer] = {
                            # Use the selected device rather than assuming CUDA.
                            k: torch.from_numpy(v).to(self.device)
                            for k, v in params.items()
                        }
                self.saes[feature.feature_id] = {
                    'params': params_by_layer[feature.layer],
                    'feature': feature
                }
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.error(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")
                continue

    def analyze_content(self, text: str) -> Dict:
        """Analyze marketing content using the loaded SAEs.

        Args:
            text: The content to analyze.

        Returns:
            A dict with keys ``'text'`` (the input), ``'features'`` (feature
            id -> result dict with ``name``, ``category``,
            ``activation_score``, ``max_activation``, ``interpretation``),
            ``'categories'`` (category -> list of those result dicts) and
            ``'recommendations'`` (list of strings).

        Raises:
            Exception: Any error during tokenization, the forward pass or
                SAE evaluation is logged and re-raised.
        """
        results = {
            'text': text,
            'features': {},
            'categories': {},
            'recommendations': []
        }

        try:
            # Get model activations
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            # Analyze each feature
            for feature_id, sae_data in self.saes.items():
                feature = sae_data['feature']
                # hidden_states[0] is the embedding output, so the residual
                # stream after decoder layer L lives at index L + 1.
                layer_output = outputs.hidden_states[feature.layer + 1]

                # Apply SAE (yields activations for every latent of the SAE)
                sae_acts = self._apply_sae(
                    layer_output,
                    sae_data['params'],
                    feature.threshold
                )
                # Keep only this feature's latent; otherwise all features on
                # the same layer would report identical scores.
                activations = sae_acts[..., feature.feature_id]

                # Record results
                feature_result = {
                    'name': feature.name,
                    'category': feature.category,
                    'activation_score': float(activations.mean()),
                    'max_activation': float(activations.max()),
                    'interpretation': self._interpret_activation(
                        activations,
                        feature
                    )
                }

                results['features'][feature_id] = feature_result

                # Aggregate by category
                results['categories'].setdefault(feature.category, []).append(feature_result)

            # Generate recommendations
            results['recommendations'] = self._generate_recommendations(results)

        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

        return results

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
        threshold: float
    ) -> torch.Tensor:
        """Encode model activations with an SAE.

        Args:
            activations: Model activations whose last dimension matches the
                first dimension of ``sae_params['W_enc']``.
            sae_params: SAE tensors; ``'W_enc'``, ``'b_enc'`` and
                ``'threshold'`` are read.
            threshold: Unused; kept for interface compatibility. Gating uses
                the per-latent ``sae_params['threshold']`` instead.

        Returns:
            Latent activations: ``relu(pre_acts)`` where the pre-activation
            exceeds the SAE's threshold, zero elsewhere.
        """
        pre_acts = activations @ sae_params['W_enc'] + sae_params['b_enc']
        mask = pre_acts > sae_params['threshold']
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def _interpret_activation(
        self,
        activations: torch.Tensor,
        feature: "MarketingFeature"
    ) -> str:
        """Turn a feature's activations into a short human-readable verdict.

        Args:
            activations: Activation values for the feature.
            feature: The feature being described; only ``name`` is read.

        Returns:
            "Very strong" if the mean activation is above 0.8, "Moderate" if
            above 0.5, otherwise "Limited" presence of the feature name.
        """
        mean_activation = float(activations.mean())
        if mean_activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif mean_activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Generate content recommendations based on analysis results.

        Args:
            results: Analysis dict; ``results['features']`` values must carry
                ``'category'`` and ``'activation_score'``.

        Returns:
            A list of recommendation strings (possibly empty).
        """
        recommendations = []

        # Analyze technical complexity
        tech_scores = [
            f['activation_score'] for f in results['features'].values()
            if f['category'] == 'technical'
        ]
        # Without any technical feature there is nothing to average;
        # np.mean([]) would warn and yield NaN.
        if tech_scores:
            tech_score = float(np.mean(tech_scores))
            if tech_score > 0.8:
                recommendations.append(
                    "Consider simplifying technical language for broader audience"
                )
            elif tech_score < 0.3:
                recommendations.append(
                    "Could benefit from more specific technical details"
                )

        # Add more recommendation logic as needed
        return recommendations

def create_gradio_interface():
    """Build the Gradio UI around a ``MarketingAnalyzer``.

    Returns:
        A ``gr.Interface``. If the analyzer cannot be constructed, the error
        is logged and a stub interface that only reports the failure is
        returned instead of the real one.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        # Degrade gracefully: still give the caller something launchable.
        error_ui = gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize. Please check if HF_TOKEN is properly set."
        )
        return error_ui

    def analyze(text):
        # Run the analysis and render it as a plain-text report.
        results = analyzer.analyze_content(text)

        # Collect report fragments and join them once at the end.
        chunks = ["Content Analysis Results\n\n", "Category Scores:\n"]

        # One averaged score per category
        for category, members in results['categories'].items():
            scores = [member['activation_score'] for member in members]
            chunks.append(f"{category.title()}: {np.mean(scores):.2f}\n")

        # Per-feature breakdown
        chunks.append("\nFeature Details:\n")
        for detail in results['features'].values():
            chunks.append(f"\n{detail['name']}:\n")
            chunks.append(f"Score: {detail['activation_score']:.2f}\n")
            chunks.append(f"Interpretation: {detail['interpretation']}\n")

        # Bullet list of recommendations
        chunks.append("\nRecommendations:\n")
        chunks.extend(f"- {rec}\n" for rec in results['recommendations'])

        return "".join(chunks)

    content_box = gr.Textbox(
        lines=5,
        placeholder="Enter your marketing content here..."
    )
    sample_inputs = [
        ["WordLift is an AI-powered SEO tool"],
        ["Our advanced machine learning algorithms optimize your content"],
        ["Simple and effective website optimization"]
    ]

    return gr.Interface(
        fn=analyze,
        inputs=content_box,
        outputs=gr.Textbox(),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=sample_inputs
    )

# Script entry point: build the interface and start the Gradio server only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()