import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import (
    DataCollatorForSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments, Trainer, Seq2SeqTrainer
)


class T5Generator:
    def __init__(self, model_checkpoint):
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer)
        # Prefer a CUDA GPU, then Apple MPS, then CPU. is_available() checks for
        # usable hardware at runtime; is_built() only checks how PyTorch was compiled.
        self.device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
    def tokenize_function_inputs(self, sample):
        """
        Tokenize one sample (or batch) of the input dataset for seq2seq training.
        """
        model_inputs = self.tokenizer(sample['text'], max_length=512, truncation=True)
        labels = self.tokenizer(sample['labels'], max_length=64, truncation=True)
        model_inputs['labels'] = labels['input_ids']
        return model_inputs
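
    # tokenize_function_inputs is designed to be applied with datasets.map; a
    # minimal sketch, assuming `generator = T5Generator(...)` and a DatasetDict
    # `dataset` with 'text' and 'labels' columns (the names are illustrative):
    #
    #   tokenized_datasets = dataset.map(generator.tokenize_function_inputs)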

    def train(self, tokenized_datasets, **kwargs):
        """
        Train the generative model; kwargs are forwarded to Seq2SeqTrainingArguments.
        """
        # Set training arguments
        args = Seq2SeqTrainingArguments(**kwargs)

        # Define trainer object
        trainer = Seq2SeqTrainer(
            self.model,
            args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["validation"] if tokenized_datasets.get("validation") is not None else None,
            tokenizer=self.tokenizer,
            data_collator=self.data_collator,
        )
        print("Trainer device:", trainer.args.device)

        # Finetune the model
        torch.cuda.empty_cache()
        print('\nModel training started ...')
        trainer.train()

        # Save the final model (the best checkpoint if load_best_model_at_end is set)
        trainer.save_model()
        return trainer

    def get_labels(self, tokenized_dataset, batch_size=4, max_length=128, sample_set='train'):
        """
        Get the predictions from the trained model.
        """
        def collate_fn(batch):
            # Pad variable-length input_ids to the longest sequence in the batch.
            input_ids = [torch.tensor(example['input_ids']) for example in batch]
            return pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        dataloader = DataLoader(tokenized_dataset[sample_set], batch_size=batch_size, collate_fn=collate_fn)
        predicted_output = []
        self.model.to(self.device)
        print('Model loaded to:', self.device)
        for batch in tqdm(dataloader):
            batch = batch.to(self.device)
            output_ids = self.model.generate(batch, max_length=max_length)
            output_texts = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            predicted_output.extend(output_texts)
        return predicted_output

    def get_metrics(self, y_true, y_pred, is_triplet_extraction=False):
        """
        Compute precision, recall, and F1 over comma-separated term lists,
        using substring overlap as the match criterion.
        """
        total_pred = 0
        total_gt = 0
        tp = 0
        if not is_triplet_extraction:
            for gt, pred in zip(y_true, y_pred):
                gt_list = gt.split(', ')
                pred_list = pred.split(', ')
                total_pred += len(pred_list)
                total_gt += len(gt_list)
                for gt_val in gt_list:
                    for pred_val in pred_list:
                        if pred_val in gt_val or gt_val in pred_val:
                            tp += 1
                            break
        else:
            for gt, pred in zip(y_true, y_pred):
                gt_list = gt.split(', ')
                pred_list = pred.split(', ')
                total_pred += len(pred_list)
                total_gt += len(gt_list)
                for gt_val in gt_list:
                    # Each triplet is expected as 'aspect:opinion:sentiment';
                    # skip entries that do not have all three fields.
                    try:
                        gt_asp, gt_op, gt_sent = gt_val.split(':')[:3]
                    except ValueError:
                        continue
                    for pred_val in pred_list:
                        try:
                            pr_asp, pr_op, pr_sent = pred_val.split(':')[:3]
                        except ValueError:
                            continue
                        if pr_asp in gt_asp and pr_op in gt_op and gt_sent == pr_sent:
                            tp += 1
                            # Count each ground-truth triplet at most once.
                            break

        # Guard the divisions so empty inputs do not raise ZeroDivisionError.
        p = tp / total_pred if total_pred else 0.0
        r = tp / total_gt if total_gt else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        return p, r, f1, None
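
# A minimal end-to-end sketch for T5Generator. The checkpoint name, output
# directory, and hyperparameters below are illustrative assumptions, not
# settings taken from this code:
#
#   generator = T5Generator('t5-base')
#   tokenized_datasets = dataset.map(generator.tokenize_function_inputs)
#   generator.train(
#       tokenized_datasets,
#       output_dir='./t5_checkpoints',  # hypothetical path
#       learning_rate=5e-5,
#       per_device_train_batch_size=8,
#       num_train_epochs=4,
#   )
#   predictions = generator.get_labels(tokenized_datasets, sample_set='test')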


class T5Classifier:
    def __init__(self, model_checkpoint):
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, force_download=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, force_download=True)
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer)
        # Prefer a CUDA GPU, then Apple MPS, then CPU.
        self.device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
    def tokenize_function_inputs(self, sample):
        """
        Tokenize one sample (or batch) of the input dataset.
        """
        sample['input_ids'] = self.tokenizer(sample['text'], max_length=512, truncation=True).input_ids
        sample['labels'] = self.tokenizer(sample['labels'], max_length=64, truncation=True).input_ids
        return sample

    def train(self, tokenized_datasets, **kwargs):
        """
        Train the generative model; kwargs are forwarded to Seq2SeqTrainingArguments.
        """
        # Set training arguments
        args = Seq2SeqTrainingArguments(**kwargs)

        # Define trainer object
        trainer = Trainer(
            self.model,
            args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["validation"] if tokenized_datasets.get("validation") is not None else None,
            tokenizer=self.tokenizer,
            data_collator=self.data_collator,
        )
        print("Trainer device:", trainer.args.device)

        # Finetune the model
        torch.cuda.empty_cache()
        print('\nModel training started ...')
        trainer.train()

        # Save the final model (the best checkpoint if load_best_model_at_end is set)
        trainer.save_model()
        return trainer

    def get_labels(self, tokenized_dataset, batch_size=4, sample_set='train'):
        """
        Get the predictions from the trained model.
        """
        def collate_fn(batch):
            # Pad variable-length input_ids to the longest sequence in the batch.
            input_ids = [torch.tensor(example['input_ids']) for example in batch]
            return pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        dataloader = DataLoader(tokenized_dataset[sample_set], batch_size=batch_size, collate_fn=collate_fn)
        predicted_output = []
        self.model.to(self.device)
        print('Model loaded to:', self.device)
        for batch in tqdm(dataloader):
            batch = batch.to(self.device)
            output_ids = self.model.generate(batch)
            output_texts = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            predicted_output.extend(output_texts)
        return predicted_output

    def get_metrics(self, y_true, y_pred):
        """
        Standard macro-averaged classification metrics over exact label strings.
        """
        return (precision_score(y_true, y_pred, average='macro'),
                recall_score(y_true, y_pred, average='macro'),
                f1_score(y_true, y_pred, average='macro'),
                accuracy_score(y_true, y_pred))
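
# A minimal usage sketch for T5Classifier, mirroring the generator workflow;
# the checkpoint name and training kwargs are illustrative assumptions.
# get_metrics expects y_true and y_pred to be aligned lists of exact label
# strings:
#
#   classifier = T5Classifier('t5-base')
#   tokenized_datasets = dataset.map(classifier.tokenize_function_inputs)
#   classifier.train(tokenized_datasets, output_dir='./cls_checkpoints',
#                    num_train_epochs=4)
#   y_pred = classifier.get_labels(tokenized_datasets, sample_set='test')
#   precision, recall, f1, accuracy = classifier.get_metrics(
#       dataset['test']['labels'], y_pred)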