# Exported from a web file viewer -- author: Varun Wadhwa; commit "Logs"
# (bfb1e05, unverified); file size 11.3 kB.
import copy
import os

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification
from transformers import DebertaV2Config, DebertaV2ForTokenClassification
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): "expandable_segments" is presumably enabled to limit memory
# fragmentation -- confirm against the PyTorch CUDA memory-management notes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Report parameter counts
def print_trainable_parameters(model):
    """Print the total and the trainable (requires_grad) parameter counts of *model*."""
    total_count = 0
    tunable_count = 0
    for param in model.parameters():
        size = param.numel()
        total_count += size
        if param.requires_grad:
            tunable_count += size
    print(f'total params: {total_count}. tunable params: {tunable_count}')
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
device = torch.device('cuda' if cuda_available else 'cpu')
# Load models
st.write('Loading the pretrained model ...')
teacher_model_name = "iiiorg/piiranha-v1-detect-personal-information"
teacher_model = AutoModelForTokenClassification.from_pretrained(teacher_model_name)
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
print(teacher_model)
print_trainable_parameters(teacher_model)

# The label vocabulary is taken from the teacher so both models share class ids.
label2id = teacher_model.config.label2id
id2label = teacher_model.config.id2label
st.write("id2label: ", id2label)
st.write("label2id: ", label2id)
dimension = len(id2label)
st.write("dimension", dimension)

# Build the student config from a *copy* of the teacher's config. Assigning
# teacher_model.config directly would alias it, so shrinking the student below
# would also rewrite the teacher's own config object.
student_model_config = copy.deepcopy(teacher_model.config)
student_model_config.num_attention_heads = 8
student_model_config.num_hidden_layers = 4
student_model = DebertaV2ForTokenClassification.from_pretrained(
    "microsoft/mdeberta-v3-base",
    config=student_model_config)
print(student_model)
print_trainable_parameters(student_model)

# `device` is the CPU when no GPU is available, so the move is safe either way.
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)
# Load data: keep the English rows, thin them to every 11th example, then
# carve out a 20% test split.
# (For a quick smoke run, additionally subsample with `.select(range(2000))`.)
raw_dataset = (
    load_dataset("ai4privacy/pii-masking-400k", split='train')
    .filter(lambda row: row["language"].startswith("en"))
    .filter(lambda row, position: position % 11 == 0, with_indices=True)
    .train_test_split(test_size=0.2)
)
print(raw_dataset)
print(raw_dataset.column_names)
# inputs = tokenizer(
# raw_dataset['train'][0]['mbert_tokens'],
# truncation=True,
# is_split_into_words=True)
# print(inputs)
# print(inputs.tokens())
# print(inputs.word_ids())
# function to align word-level labels with sub-word tokens
# --> special tokens: -100 label id (ignored by cross entropy),
# --> only the first token of each word keeps the word's label,
# --> a leading 'B-' on a word label is rewritten to 'I-'
def align_labels_with_tokens(label, word_ids):
    """Expand per-word labels to one label per token.

    Args:
        label: sequence of word-level labels for one example, indexed by word
            position (e.g. ``["B-CITY", "I-CITY", "O"]``).
        word_ids: for each token, the index of the word it came from, or
            ``None`` for special/padding tokens.

    Returns:
        A list with one entry per token: ``-100`` for special tokens and for
        every non-first token of a word, otherwise the word's label with a
        leading ``"B-"`` replaced by ``"I-"``. Non-string labels are passed
        through unchanged. The input ``label`` sequence is not modified.
    """
    aligned_label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            # Special tokens carry no label.
            aligned_label_ids.append(-100)
        elif word_idx != previous_word_idx:
            # First token of a new word: look up that word's own label.
            word_label = label[word_idx]
            # NOTE(review): the B- -> I- rewrite presumably exists because the
            # target label set has no 'B-' tags -- confirm against label2id.
            if isinstance(word_label, str) and word_label.startswith("B-"):
                word_label = "I-" + word_label[2:]  # rewrite the prefix only
            aligned_label_ids.append(word_label)
        else:
            # Continuation tokens of the same word are ignored by the loss.
            aligned_label_ids.append(-100)
        previous_word_idx = word_idx
    return aligned_label_ids
# create tokenize function
def tokenize_function(examples):
    """Tokenize a batch of pre-split examples and attach token-level label ids.

    Args:
        examples: a batch with ``'mbert_tokens'`` (word lists) and
            ``'mbert_token_classes'`` (one label per word).

    Returns:
        The tokenizer output for the batch with an added ``"labels"`` entry
        holding, per example, one integer per token (``-100`` where the
        aligned label is ``-100``).

    Raises:
        KeyError: if an aligned string label is missing from ``label2id``.
    """
    # padding=True pads every example to the longest one in this map batch.
    inputs = tokenizer(
        examples['mbert_tokens'],
        is_split_into_words=True,
        padding=True,
        truncation=True,
        max_length=512)

    new_labels = []
    for i, label in enumerate(examples['mbert_token_classes']):
        word_ids = inputs.word_ids(batch_index=i)
        aligned = align_labels_with_tokens(label, word_ids)
        # The torch dataset format and the loss need integer class ids, so map
        # any string tag through the teacher's vocabulary; -100 passes through.
        new_labels.append(
            [label2id[tag] if isinstance(tag, str) else tag for tag in aligned])

    inputs["labels"] = new_labels
    return inputs
# Tokenize both splits in batches and expose only the model inputs as tensors.
model_columns = ['input_ids', 'attention_mask', 'labels']
tokenized_data = raw_dataset.map(tokenize_function, batched=True)
tokenized_data.set_format(type='torch', columns=model_columns)

# Collator used by the data loaders to assemble batches.
data_collator = DataCollatorForTokenClassification(tokenizer)

# Show the labels of the first two training examples in the app.
label_preview = tokenized_data["train"][:2]
st.write(label_preview["labels"])
# Function to evaluate model performance
def evaluate_model(model, dataloader, device):
    """Score *model* on *dataloader* at token level.

    Only positions that are attended to AND carry a real label are scored.
    Positions labelled -100 (special tokens, continuation sub-words, padding)
    are skipped; they have no entry in ``id2label`` and are not a class.

    Args:
        model: token-classification model whose output exposes ``.logits``.
        dataloader: yields batches with ``'input_ids'``, ``'attention_mask'``
            and ``'labels'`` tensors.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(accuracy, precision, recall, f1)``; precision/recall/F1 are
        micro-averaged. All zeros when there is no labelled token to score.
    """
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []
    sample_count = 0
    num_samples = 100  # number of examples echoed to stdout for inspection

    # Disable gradient calculations
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Labels are only needed on the host, so skip the device round trip.
            labels = batch['labels'].cpu().numpy()

            # Forward pass, then take the arg-max class per token.
            outputs = model(input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

            # Keep real (attended) tokens that also have a label to compare with.
            attended = batch['attention_mask'].cpu().numpy().astype(bool)
            mask = attended & (labels != -100)

            # Process each sequence in the batch
            for i in range(input_ids.size(0)):
                valid_preds = preds[i][mask[i]].tolist()
                valid_labels = labels[i][mask[i]].tolist()
                all_preds.extend(valid_preds)
                all_labels.extend(valid_labels)

                if sample_count < num_samples:
                    print(f"Sample {sample_count + 1}:")
                    print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i].tolist())}")
                    print(f"True Labels: {[id2label[label] for label in valid_labels]}")
                    print(f"Predicted Labels: {[id2label[pred] for pred in valid_preds]}")
                    print("-" * 50)
                    sample_count += 1

    # Calculate evaluation metrics
    print("evaluate_model sizes")
    print(len(all_preds))
    print(len(all_labels))
    if not all_labels:
        # Nothing to score; avoid handing empty arrays to the metric functions.
        return 0.0, 0.0, 0.0, 0.0

    # Class ids are categorical, so keep them as integers.
    all_preds = np.asarray(all_preds, dtype=np.int64)
    all_labels = np.asarray(all_labels, dtype=np.int64)
    accuracy = accuracy_score(all_labels, all_preds)
    # NOTE(review): micro-averaging over every class of a single-label problem
    # yields precision == recall == F1 == accuracy; consider 'macro' if a
    # per-class view is wanted.
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
    return accuracy, precision, recall, f1
# Function to compute distillation and hard-label loss
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    """Blend a temperature-scaled KL distillation term with hard-label cross-entropy.

    Both terms are averaged over the same tokens: the positions whose label is
    not -100. Ignored positions (padding, special tokens, unlabeled sub-words)
    therefore contribute to neither term, and the two terms stay on the same
    per-token scale regardless of sequence length.

    Args:
        student_logits: tensor whose last dimension indexes classes.
        teacher_logits: tensor with the same shape as ``student_logits``.
        true_labels: integer class ids, one per position of the logits'
            leading dimensions; ``-100`` marks positions to ignore.
        temperature: softening factor applied to both logits in the KL term.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            hard-label term.

    Returns:
        A scalar tensor. If every position is ignored, a zero that is still
        connected to ``student_logits`` so ``backward()`` remains valid.
    """
    num_classes = student_logits.size(-1)
    flat_labels = true_labels.reshape(-1)
    keep = flat_labels != -100
    if not keep.any():
        # A mean over zero tokens would be NaN; return a graph-connected zero.
        return student_logits.sum() * 0.0

    # Flatten to (tokens, classes) and drop the ignored positions.
    student_flat = student_logits.reshape(-1, num_classes)[keep]
    teacher_flat = teacher_logits.reshape(-1, num_classes)[keep]
    target_flat = flat_labels[keep]

    # Soft targets from the teacher, log-probabilities from the student.
    soft_targets = nn.functional.softmax(teacher_flat / temperature, dim=-1)
    student_soft = nn.functional.log_softmax(student_flat / temperature, dim=-1)

    # On 2-D input 'batchmean' divides by the number of kept tokens. The T^2
    # factor keeps gradient magnitudes comparable across temperatures.
    distill_loss = nn.functional.kl_div(student_soft, soft_targets, reduction='batchmean') * (temperature ** 2)

    # Cross-entropy loss for hard labels, a mean over the same tokens.
    hard_loss = nn.functional.cross_entropy(student_flat, target_flat)

    # Combine losses
    return alpha * distill_loss + (1.0 - alpha) * hard_loss
# hyperparameters
batch_size = 32
lr = 1e-4
num_epochs = 30
temperature = 2.0  # softens teacher and student distributions in the distillation term
alpha = 0.5  # weight of the distillation term; (1 - alpha) weights the hard-label term

# define optimizer (only the student's parameters are updated)
optimizer = optim.Adam(student_model.parameters(), lr=lr)

# create training data loader; reshuffle every epoch so batches are not
# replayed in the same fixed order throughout training
dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size, shuffle=True, collate_fn=data_collator)

# create testing data loader (fixed order keeps evaluations comparable)
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size, collate_fn=data_collator)

# Baseline: score the student before any distillation step.
untrained_student_accuracy, untrained_student_precision, untrained_student_recall, untrained_student_f1 = evaluate_model(student_model, test_dataloader, device)
print(f"Untrained Student (test) - Accuracy: {untrained_student_accuracy:.4f}, Precision: {untrained_student_precision:.4f}, Recall: {untrained_student_recall:.4f}, F1 Score: {untrained_student_f1:.4f}")
# The teacher is never updated (it is not in the optimizer and runs under
# no_grad), so its test metrics cannot change between epochs: evaluate it once
# up front instead of once per epoch.
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

# put student model in train mode
student_model.train()

# train model
for epoch in range(num_epochs):
    for batch in dataloader:
        # Prepare inputs
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Compute the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The value reported is the loss of the epoch's last batch.
    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Evaluate the student model
    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
    print("\n")

    # evaluate_model switched the student to eval mode; put it back into train mode
    student_model.train()
# Final side-by-side comparison of teacher and student.
# The loader below reads the 'test' split again, this time in batches of 8.
validation_dataloader = DataLoader(
    tokenized_data['test'],
    batch_size=8,
    collate_fn=data_collator,
)

# Teacher metrics
teacher_metrics = evaluate_model(teacher_model, validation_dataloader, device)
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = teacher_metrics
print(
    f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, "
    f"Precision: {teacher_precision:.4f}, "
    f"Recall: {teacher_recall:.4f}, "
    f"F1 Score: {teacher_f1:.4f}"
)

# Student metrics
student_metrics = evaluate_model(student_model, validation_dataloader, device)
student_accuracy, student_precision, student_recall, student_f1 = student_metrics
print(
    f"Student (validation) - Accuracy: {student_accuracy:.4f}, "
    f"Precision: {student_precision:.4f}, "
    f"Recall: {student_recall:.4f}, "
    f"F1 Score: {student_f1:.4f}"
)
st.write('Pushing model to huggingface')

# Upload the distilled student to the Hugging Face Hub. The repository id is
# "<user or org>/<model name>"; the access token is read from Streamlit secrets.
hf_name = 'CarolXia'  # your hf username or org name
mode_name = "pii-kd-deberta-v2"
model_id = "/".join((hf_name, mode_name))
hub_token = st.secrets["HUGGINGFACE_TOKEN"]
student_model.push_to_hub(model_id, token=hub_token)