Varun Wadhwa committed

Commit 5e89ee6 · unverified · 1 parent: 2e13708

Copying over

Files changed (2):
  1. app.py +252 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,252 @@
+ import streamlit as st
+
+ from datasets import load_dataset
+
+ import copy
+ import numpy as np
+ import os
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification
+ from transformers import DebertaV2ForTokenClassification
+
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ # print parameter counts
+ def print_trainable_parameters(model):
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(f'total params: {total_params}. tunable params: {trainable_params}')
+
+ device = torch.device('cpu')
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
+ if torch.cuda.is_available():
+     print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+     device = torch.device('cuda')
+
+ # Load models
+ st.write('Loading the pretrained model ...')
+ teacher_model_name = "iiiorg/piiranha-v1-detect-personal-information"
+ teacher_model = AutoModelForTokenClassification.from_pretrained(teacher_model_name)
+ tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
+ print(teacher_model)
+ print_trainable_parameters(teacher_model)
+ label2id = teacher_model.config.label2id
+ id2label = teacher_model.config.id2label
+
+ st.write("id2label: ", id2label)
+ st.write("label2id: ", label2id)
+ dimension = len(id2label)
+ st.write("dimension", dimension)
+
+ # Copy the teacher's config (deepcopy, so the teacher's own config is not
+ # mutated by the edits below) and shrink it for the student.
+ student_model_config = copy.deepcopy(teacher_model.config)
+ student_model_config.num_attention_heads = 8
+ student_model_config.num_hidden_layers = 4
+ student_model = DebertaV2ForTokenClassification.from_pretrained(
+     "microsoft/mdeberta-v3-base",
+     config=student_model_config)
+     # ignore_mismatched_sizes=True)
+ print(student_model)
+ print_trainable_parameters(student_model)
+
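+ # Note: loading the 12-layer mdeberta-v3-base checkpoint into this 4-layer
+ # config keeps only the first 4 encoder layers (the remaining weights are
+ # skipped with a warning), and the token-classification head is freshly
+ # initialized with the teacher's label space.
+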
+ if torch.cuda.is_available():
+     teacher_model = teacher_model.to(device)
+     student_model = student_model.to(device)
+
+ # Load data.
+ raw_dataset = load_dataset("ai4privacy/pii-masking-400k", split='train')
+ raw_dataset = raw_dataset.filter(lambda example: example["language"].startswith("en"))
+ # keep every 11th example to shrink the dataset for faster iteration
+ raw_dataset = raw_dataset.filter(lambda example, idx: idx % 11 == 0, with_indices=True)
+ raw_dataset = raw_dataset.train_test_split(test_size=0.2)
+ print(raw_dataset)
+ print(raw_dataset.column_names)
+
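+ # Each example provides pre-tokenized words ('mbert_tokens') with one PII tag
+ # per word ('mbert_token_classes'); these are re-aligned to this tokenizer's
+ # subword tokens below.
+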
+ # function to align labels with tokens
+ # --> special and padding tokens (word_id None) get label -100, which
+ #     cross entropy ignores
+ # --> 'B-' prefixes are mapped to 'I-' to match the teacher's label set
+ def align_labels_with_tokens(labels, word_ids):
+     aligned_label_ids = []
+     for word_id in word_ids:
+         if word_id is None:
+             aligned_label_ids.append(-100)
+             continue
+         label = labels[word_id]
+         if label.startswith("B-"):
+             label = label.replace("B-", "I-")
+         aligned_label_ids.append(label2id[label])
+     return aligned_label_ids
+
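+ # Illustrative (label names assumed): words ['Jane', 'Doe'] tagged
+ # ['I-GIVENNAME', 'I-SURNAME'] with word_ids [None, 0, 0, 1, None]
+ # -> [-100, I-GIVENNAME, I-GIVENNAME, I-SURNAME, -100] as label ids.
+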
+ # create tokenize function
+ def tokenize_function(examples):
+     # tokenize and truncate the pre-split words; examples arrives as a batch
+     # of columns from the train or test split
+     new_labels = []
+     inputs = tokenizer(
+         examples['mbert_tokens'],
+         is_split_into_words=True,
+         padding=True,
+         truncation=True,
+         max_length=512)
+     for i, labels in enumerate(examples['mbert_token_classes']):
+         # word_ids maps each subword token of example i back to its source
+         # word, with None for special and padding tokens
+         new_labels.append(align_labels_with_tokens(labels, inputs.word_ids(batch_index=i)))
+
+     inputs["labels"] = new_labels
+     return inputs
+
+ # tokenize training and test datasets
+ tokenized_data = raw_dataset.map(
+     tokenize_function,
+     batched=True)
+ tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
+ # data collator: pads inputs and labels to the longest sequence in each batch
+ data_collator = DataCollatorForTokenClassification(tokenizer)
+
+ st.write(tokenized_data["train"][:2]["labels"])
+
+ # Function to evaluate model performance
+ def evaluate_model(model, dataloader, device):
+     model.eval()  # Set model to evaluation mode
+     all_preds = []
+     all_labels = []
+
+     # Disable gradient calculations
+     with torch.no_grad():
+         for batch in dataloader:
+             input_ids = batch['input_ids'].to(device)
+             attention_mask = batch['attention_mask'].to(device)
+             labels = batch['labels'].to(device)
+
+             # Forward pass to get logits
+             outputs = model(input_ids, attention_mask=attention_mask)
+             logits = outputs.logits
+
+             # Get predictions; flatten per batch so batches with different
+             # padded lengths can be concatenated
+             preds = torch.argmax(logits, dim=-1).cpu().numpy()
+             all_preds.extend(preds.flatten())
+             all_labels.extend(labels.cpu().numpy().flatten())
+
+     # Calculate evaluation metrics, dropping the -100 positions (special
+     # tokens and padding) so they do not inflate the scores
+     all_preds = np.asarray(all_preds)
+     all_labels = np.asarray(all_labels)
+     keep = all_labels != -100
+     all_preds = all_preds[keep]
+     all_labels = all_labels[keep]
+     accuracy = accuracy_score(all_labels, all_preds)
+     precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
+
+     return accuracy, precision, recall, f1
+
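+ # Distillation objective (Hinton et al., 2015), as implemented below:
+ #   L = alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels)
+ # where _T denotes distributions softened with temperature T.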
+ # Function to compute distillation and hard-label loss
+ def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
+     # Compute soft targets from teacher logits
+     soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
+     student_soft = nn.functional.log_softmax(student_logits / temperature, dim=-1)
+
+     # KL Divergence loss for distillation; the temperature**2 factor keeps
+     # gradient magnitudes comparable across temperatures
+     distill_loss = nn.functional.kl_div(student_soft, soft_targets, reduction='batchmean') * (temperature ** 2)
+
+     # Cross-entropy loss for hard labels; transpose logits to
+     # (batch, classes, seq) to match the labels' (batch, seq) layout
+     student_logit_reshape = torch.transpose(student_logits, 1, 2)
+     hard_loss = nn.CrossEntropyLoss()(student_logit_reshape, true_labels)
+
+     # Combine losses
+     loss = alpha * distill_loss + (1.0 - alpha) * hard_loss
+
+     return loss
+
+ # hyperparameters
+ batch_size = 32
+ lr = 1e-4
+ num_epochs = 30
+ temperature = 2.0
+ alpha = 0.5
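+ # temperature > 1 softens the teacher's distribution so the student also
+ # sees relative similarities between classes; alpha weights the distillation
+ # term against the hard-label cross entropy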
+
+ # define optimizer
+ optimizer = optim.Adam(student_model.parameters(), lr=lr)
+
+ # create training data loader (shuffled each epoch)
+ dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size, shuffle=True, collate_fn=data_collator)
+ # create testing data loader
+ test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator)
+
+ # put student model in train mode
+ student_model.train()
+
+ # train model
+ for epoch in range(num_epochs):
+     for batch in dataloader:
+         # Prepare inputs
+         input_ids = batch['input_ids'].to(device)
+         attention_mask = batch['attention_mask'].to(device)
+         labels = batch['labels'].to(device)
+
+         # Disable gradient calculation for teacher model
+         with torch.no_grad():
+             teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
+             teacher_logits = teacher_outputs.logits
+
+         # Forward pass through the student model
+         student_outputs = student_model(input_ids, attention_mask=attention_mask)
+         student_logits = student_outputs.logits
+
+         # Compute the distillation loss
+         loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)
+
+         # Backpropagation
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+     print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")
+
+     # Evaluate the teacher model
+     teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
+     print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
+
+     # Evaluate the student model
+     student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
+     print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
+     print("\n")
+
+     # put student model back into train mode
+     student_model.train()
+
+ # Compare the models after training
+ # create validation data loader
+ validation_dataloader = DataLoader(tokenized_data['test'], batch_size=8, collate_fn=data_collator)
+ # Evaluate the teacher model
+ teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
+ print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
+ # Evaluate the student model
+ student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
+ print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
+
+ st.write('Pushing model to huggingface')
+
+ # Push model to huggingface
+ hf_name = 'CarolXia' # your hf username or org name
+ model_name = "pii-kd-deberta-v2"
+ model_id = hf_name + "/" + model_name
+ student_model.push_to_hub(model_id, token=st.secrets["HUGGINGFACE_TOKEN"])
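+ # The distilled checkpoint can later be loaded back with (illustrative):
+ #   AutoModelForTokenClassification.from_pretrained("CarolXia/pii-kd-deberta-v2")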
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ auto-gptq
+ bitsandbytes
+ datasets
+ evaluate
+ seqeval
+ gliner
+ torch>=2.0.0
+ transformers>=4.38.2
+ huggingface_hub>=0.21.4
+ onnxruntime
+ optimum
+ peft
+ scikit-learn
+ sentencepiece
+ tqdm
+