import copy
import os

# Configure the CUDA caching allocator before torch is imported so the setting
# is guaranteed to be in place before any CUDA memory is allocated.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np  # noqa: E402
import streamlit as st  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.optim as optim  # noqa: E402
from datasets import load_dataset  # noqa: E402
from sklearn.metrics import accuracy_score, precision_recall_fscore_support  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
)
from transformers import DebertaV2Config, DebertaV2ForTokenClassification  # noqa: E402


def print_trainable_parameters(model):
    """Print the total parameter count of *model* and how many require grad."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'total params: {total}. tunable params: {trainable}')


# ---------------------------------------------------------------------------
# Device selection
# ---------------------------------------------------------------------------
device = torch.device('cpu')
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    device = torch.device('cuda')

# ---------------------------------------------------------------------------
# Load teacher / student models
# ---------------------------------------------------------------------------
st.write('Loading the pretrained model ...')
teacher_model_name = "iiiorg/piiranha-v1-detect-personal-information"
teacher_model = AutoModelForTokenClassification.from_pretrained(teacher_model_name)
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
print(teacher_model)
print_trainable_parameters(teacher_model)

label2id = teacher_model.config.label2id
id2label = teacher_model.config.id2label
st.write("id2label: ", id2label)
st.write("label2id: ", label2id)
dimension = len(id2label)
st.write("dimension", dimension)

# Copy the teacher's config. Assigning it directly would alias the teacher's
# config object, so shrinking layers/heads below would also rewrite the
# teacher's reported (and saved) configuration.
student_model_config = copy.deepcopy(teacher_model.config)
# NOTE(review): hidden_size is inherited unchanged, so 8 heads regroups the
# pretrained projection weights into different-width heads than the
# checkpoint was trained with -- confirm this is intended.
student_model_config.num_attention_heads = 8
student_model_config.num_hidden_layers = 4
student_model = DebertaV2ForTokenClassification.from_pretrained(
    "microsoft/mdeberta-v3-base", config=student_model_config)
print(student_model)
print_trainable_parameters(student_model)

# `device` is CPU when CUDA is unavailable, so this is a no-op in that case.
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# ---------------------------------------------------------------------------
# Load and subsample data
# ---------------------------------------------------------------------------
raw_dataset = load_dataset("ai4privacy/pii-masking-400k", split='train')
raw_dataset = raw_dataset.filter(lambda example: example["language"].startswith("en"))
# For quick debugging: raw_dataset = raw_dataset.select(range(2000))
# Keep every 11th English example to shrink the dataset.
raw_dataset = raw_dataset.filter(lambda example, idx: idx % 11 == 0, with_indices=True)
# Fixed seed so the train/test split is reproducible across runs.
raw_dataset = raw_dataset.train_test_split(test_size=0.2, seed=42)
print(raw_dataset)
print(raw_dataset.column_names)


def _tag_to_id(tag, mapping):
    """Return the label id for one word-level tag.

    Falls back from "B-X" to "I-X" when only the latter is in *mapping*.

    Raises:
        KeyError: if neither form of the tag has an id.
    """
    if isinstance(tag, int):
        return tag
    if tag in mapping:
        return mapping[tag]
    # The teacher's label set may only contain "I-" tags (TODO confirm), so
    # beginning-of-entity tags are folded into their "inside" form.
    if tag.startswith("B-"):
        inside_tag = "I-" + tag[len("B-"):]
        if inside_tag in mapping:
            return mapping[inside_tag]
    raise KeyError(f"Label {tag!r} has no id in the teacher's label2id")


def align_labels_with_tokens(label, word_ids, label_to_id=None):
    """Align one example's word-level tags to its sub-word tokens as label ids.

    Args:
        label: the example's word-level tags, one per word. These are presumably
            strings such as "B-GIVENNAME" / "I-GIVENNAME" / "O" (verify against
            the dataset's ``mbert_token_classes`` schema); ints are used as-is.
        word_ids: the tokenizer's ``word_ids()`` for the example, where ``None``
            marks special (and padding) tokens.
        label_to_id: tag -> id mapping; defaults to the teacher's ``label2id``.

    Returns:
        One id per token. The first sub-token of each word gets the word's id.
        Special tokens and continuation sub-tokens get -100, which is ignored
        by cross entropy and by ``evaluate_model``.

    Raises:
        KeyError: if a tag has no id (neither as-is nor as its "I-" form).
    """
    mapping = label2id if label_to_id is None else label_to_id
    aligned_label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None or word_idx == previous_word_idx:
            aligned_label_ids.append(-100)
        else:
            aligned_label_ids.append(_tag_to_id(label[word_idx], mapping))
        previous_word_idx = word_idx
    return aligned_label_ids


def tokenize_function(examples):
    """Tokenize a batch of pre-split words and attach token-aligned label ids.

    Intended for ``Dataset.map(..., batched=True)``. Sequences are truncated to
    512 tokens but not padded: ``DataCollatorForTokenClassification`` pads each
    DataLoader batch to its own longest sequence (labels padded with -100).
    """
    inputs = tokenizer(
        examples['mbert_tokens'],
        is_split_into_words=True,
        truncation=True,
        max_length=512)
    inputs["labels"] = [
        align_labels_with_tokens(tags, inputs.word_ids(batch_index=i))
        for i, tags in enumerate(examples['mbert_token_classes'])
    ]
    return inputs


# Tokenize training and test splits; the raw columns are no longer needed.
tokenized_data = raw_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_dataset["train"].column_names)
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

data_collator = DataCollatorForTokenClassification(tokenizer)
st.write(tokenized_data["train"][:2]["labels"])


def evaluate_model(model, dataloader, device, num_samples=100):
    """Evaluate *model* on *dataloader* and return token-level metrics.

    Only positions whose label is not -100 are scored (the first sub-token of
    each word). This excludes special tokens, padding and continuation
    sub-tokens. The first *num_samples* sequences are printed for inspection.

    The model is left in eval mode; callers re-enable train mode themselves.

    Returns:
        ``(accuracy, precision, recall, f1)``:
        - accuracy covers every scored token;
        - precision/recall/F1 are micro-averaged over the entity labels only,
          i.e. all labels except "O" when "O" is in ``label2id``. With "O"
          included, a micro average would just reproduce accuracy three times.
        All four are 0.0 when nothing was scored.
    """
    model.eval()
    all_preds = []
    all_labels = []
    sample_count = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            labels = batch['labels'].cpu().numpy()
            token_ids = batch['input_ids'].cpu().numpy()

            for seq_ids, seq_preds, seq_labels in zip(token_ids, preds, labels):
                scored = seq_labels != -100
                valid_preds = seq_preds[scored]
                valid_labels = seq_labels[scored]
                all_preds.extend(valid_preds.tolist())
                all_labels.extend(valid_labels.tolist())

                if sample_count < num_samples:
                    # Tokens are filtered by the same mask so they line up
                    # one-to-one with the printed labels.
                    print(f"Sample {sample_count + 1}:")
                    print(f"Tokens: {tokenizer.convert_ids_to_tokens(seq_ids[scored].tolist())}")
                    print(f"True Labels: {[id2label[i] for i in valid_labels.tolist()]}")
                    print(f"Predicted Labels: {[id2label[i] for i in valid_preds.tolist()]}")
                    print("-" * 50)
                    sample_count += 1

    print(f"evaluate_model scored tokens: {len(all_labels)}")
    if not all_labels:
        return 0.0, 0.0, 0.0, 0.0

    accuracy = accuracy_score(all_labels, all_preds)
    outside_id = label2id.get("O")
    entity_ids = [i for i in sorted(id2label) if i != outside_id]
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds,
        labels=entity_ids or None,
        average='micro',
        zero_division=0)
    return accuracy, precision, recall, f1


def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha,
                      attention_mask=None):
    """Blend a soft-target KL distillation loss with a hard-label CE loss.

    Args:
        student_logits, teacher_logits: per-token logits, presumably
            (batch, seq, num_labels), with the label axis last.
        true_labels: (batch, seq) label ids, with -100 marking ignored positions.
        temperature: softening temperature applied to both logit sets.
        alpha: weight of the distillation term; ``1 - alpha`` weights CE.
        attention_mask: optional (batch, seq) mask selecting which positions
            are distilled. Defaults to the positions that carry a label.

    Returns:
        ``alpha * KL * T**2 + (1 - alpha) * CE``. Both terms are mean-per-token,
        so ``alpha`` balances quantities of comparable scale. (The previous
        ``batchmean`` KL summed over every position, padding included.)
        CE is NaN if no position carries a label, as with ``nn.CrossEntropyLoss``.
    """
    if attention_mask is None:
        token_mask = true_labels != -100
    else:
        token_mask = attention_mask.bool()

    soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
    student_log_soft = nn.functional.log_softmax(student_logits / temperature, dim=-1)
    # Per-token KL(teacher || student), summed over the label axis -> (batch, seq).
    per_token_kl = nn.functional.kl_div(
        student_log_soft, soft_targets, reduction='none').sum(dim=-1)
    weights = token_mask.to(per_token_kl.dtype)
    distill_loss = (per_token_kl * weights).sum() / weights.sum().clamp(min=1.0)
    distill_loss = distill_loss * (temperature ** 2)

    # CrossEntropyLoss expects the class axis second: (batch, num_labels, seq).
    hard_loss = nn.functional.cross_entropy(
        torch.transpose(student_logits, 1, 2), true_labels, ignore_index=-100)

    return alpha * distill_loss + (1.0 - alpha) * hard_loss


# ---------------------------------------------------------------------------
# Hyperparameters, optimizer, data loaders
# ---------------------------------------------------------------------------
batch_size = 32
lr = 1e-4
num_epochs = 30
temperature = 2.0
alpha = 0.5

optimizer = optim.Adam(student_model.parameters(), lr=lr)

# Shuffle the training data so each epoch sees batches in a new order.
dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size,
                        shuffle=True, collate_fn=data_collator)
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size,
                             collate_fn=data_collator)

(untrained_student_accuracy, untrained_student_precision,
 untrained_student_recall, untrained_student_f1) = evaluate_model(
    student_model, test_dataloader, device)
print(f"Untrained Student (test) - Accuracy: {untrained_student_accuracy:.4f}, Precision: {untrained_student_precision:.4f}, Recall: {untrained_student_recall:.4f}, F1 Score: {untrained_student_f1:.4f}")

# The teacher is never updated, so its test metrics are identical every epoch.
# Evaluate it once up front instead of once per epoch.
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(
    teacher_model, test_dataloader, device)
print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

# The teacher only supplies targets: keep dropout off while distilling.
teacher_model.eval()
# evaluate_model leaves the student in eval mode; switch it back for training.
student_model.train()

for epoch in range(num_epochs):
    # Sum losses as a tensor so there is no host sync per batch; convert once.
    epoch_loss_total = 0.0
    num_batches = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Teacher forward pass without building a graph.
        with torch.no_grad():
            teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits

        student_logits = student_model(input_ids, attention_mask=attention_mask).logits
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_total += loss.detach()
        num_batches += 1

    # Report the mean over the epoch rather than just the final batch's loss.
    mean_loss = float(epoch_loss_total) / max(num_batches, 1)
    print(f"Epoch {epoch + 1} completed with mean loss: {mean_loss:.4f}")

    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(
        student_model, test_dataloader, device)
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
    print("\n")

    # evaluate_model switched the student to eval mode; resume training mode.
    student_model.train()

# ---------------------------------------------------------------------------
# Compare the models
# ---------------------------------------------------------------------------
# NOTE(review): this "validation" loader reads the same 'test' split used for
# the per-epoch evaluation above; there is no separate held-out split.
validation_dataloader = DataLoader(tokenized_data['test'], batch_size=8, collate_fn=data_collator)

teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(
    teacher_model, validation_dataloader, device)
print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(
    student_model, validation_dataloader, device)
print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")

# ---------------------------------------------------------------------------
# Push to the Hugging Face Hub
# ---------------------------------------------------------------------------
st.write('Pushing model to huggingface')
hf_name = 'CarolXia'  # your hf username or org name
model_name = "pii-kd-deberta-v2"
model_id = f"{hf_name}/{model_name}"
hf_token = st.secrets["HUGGINGFACE_TOKEN"]
student_model.push_to_hub(model_id, token=hf_token)
# The student was trained on this tokenizer's ids. Publish it alongside the
# weights so the repo can be loaded on its own.
tokenizer.push_to_hub(model_id, token=hf_token)