from torch import nn
import torch
import numpy as np
from copy import deepcopy
import re
import unicodedata
from torch.utils.data import Dataset, DataLoader,TensorDataset, RandomSampler
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from copy import deepcopy
import gc
import torch
import numpy as np
from torchmetrics import functional as fn
import random
# Pre-trained model
class Encoder(nn.Module):
    """Wrap a copy of a pre-trained model and keep only selected encoder layers.

    The model passed in is deep-copied, so the caller's instance is never
    modified.

    Args:
        layers: Iterable of indices into ``model.encoder.layer``. The listed
            layers are kept, in the given order; every other layer is dropped.
        freeze_bert: When truthy, ``requires_grad`` is switched off for every
            parameter of the copied model.
        model: Pre-trained model exposing ``encoder.layer`` and returning a
            mapping with a ``'pooler_output'`` entry when called.
    """

    def __init__(self, layers, freeze_bert, model):
        super().__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Pre-trained model (copied so the caller's object stays untouched)
        self.model = deepcopy(model)

        # Freezing bert parameters: a frozen weight has requires_grad=False.
        # Assigning the ``freeze_bert`` flag itself would set it to True and
        # leave every weight trainable.
        if freeze_bert:
            for param in self.model.parameters():
                param.requires_grad = False

        # Selecting hidden layers of the pre-trained model
        all_layers = self.model.encoder.layer
        self.model.encoder.layer = nn.ModuleList(
            [all_layers[i] for i in layers]
        )

    # Feed forward
    def forward(self, **x):
        """Run the wrapped model on ``x`` and return its ``'pooler_output'``."""
        return self.model(**x)['pooler_output']
# Complete model
class SLR_Classifier(nn.Module):
    """Binary classifier: encoder -> feature map -> single logit.

    Keyword Args:
        pos_weight: Positive-class weight handed to the BCE-with-logits loss
            (default 2.5).
        bert_layers: Encoder layer indices forwarded to ``Encoder``
            (default ``range(12)``).
        freeze_bert: Forwarded to ``Encoder`` (default ``False``).
        model: Pre-trained model forwarded to ``Encoder``.
        drop: Dropout probability applied after the 200-unit projection
            (default 0.5).
    """

    def __init__(self, **data):
        super().__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Loss function: binary cross entropy on logits, averaged over the batch
        positive_weight = torch.FloatTensor([data.get("pos_weight", 2.5)])
        self.loss_fn = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=positive_weight,
        )

        # Pre-trained model
        self.Encoder = Encoder(
            layers=data.get("bert_layers", range(12)),
            freeze_bert=data.get("freeze_bert", False),
            model=data.get("model"),
        )

        # Feature Map Layer: normalise, project to 200 units, then dropout
        hidden_size = self.Encoder.model.config.hidden_size
        normalisation = nn.BatchNorm1d(hidden_size)
        projection = nn.Linear(hidden_size, 200)
        dropout = nn.Dropout(data.get("drop", 0.5))
        self.feature_map = nn.Sequential(normalisation, projection, dropout)

        # Classifier Layer: tanh followed by a single-logit linear head
        self.classifier = nn.Sequential(
            nn.Tanh(),
            nn.Linear(200, 1),
        )

        # Initializing layer parameters of the projection:
        # near-zero weights and a zero bias
        nn.init.normal_(projection.weight, mean=0, std=0.00001)
        nn.init.zeros_(projection.bias)

    # Feed forward
    def forward(self, input_ids, attention_mask, token_type_ids, labels):
        """Return ``[loss, [feature, logit], probability]`` for one batch.

        ``labels`` gets a trailing dimension added before it is compared with
        the logits.
        """
        pooled = self.Encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        feature = self.feature_map(pooled)
        logit = self.classifier(feature)
        predict = torch.sigmoid(logit)

        # Loss function
        target = labels.to(torch.float).unsqueeze(1)
        loss = self.loss_fn(logit.to(torch.float), target)

        return [loss, [feature, logit], predict]
# Undesirable patterns within texts
# Substrings removed from (or rewritten in) a text before tokenisation.
# Keys are lower-cased on the last line of this block, so matching is meant to
# run against lower-cased text.
patterns = {
    # Structured-abstract section headings
    'CONCLUSIONS AND IMPLICATIONS': '',
    'BACKGROUND AND PURPOSE': '',
    'EXPERIMENTAL APPROACH': '',
    'KEY RESULTS AEA': '',
    '©': '',
    '®': '',
    'μ': '',
    '(C)': '',
    'OBJECTIVE:': '',
    'MATERIALS AND METHODS:': '',
    'SIGNIFICANCE:': '',
    'BACKGROUND:': '',
    'RESULTS:': '',
    'METHODS:': '',
    'CONCLUSIONS:': '',
    'AIM:': '',
    'STUDY DESIGN:': '',
    'CLINICAL RELEVANCE:': '',
    'CONCLUSION:': '',
    'HYPOTHESIS:': '',
    'Questions/Purposes:': '',
    'Introduction:': '',
    'PURPOSE:': '',
    'PATIENTS AND METHODS:': '',
    'FINDINGS:': '',
    'INTERPRETATIONS:': '',
    'FUNDING:': '',
    'PROGRESS:': '',
    'CONTEXT:': '',
    'MEASURES:': '',
    'DESIGN:': '',
    'BACKGROUND AND OBJECTIVES:': '',
    # NOTE(review): in the reviewed copy two keys at this position were raw
    # line-break characters that split the quoted literal across lines, and
    # the exact code points could not be recovered. A raw "\n" or "\r" cannot
    # sit inside a single-quoted literal, so the remaining characters that
    # str.splitlines() treats as line boundaries are listed explicitly.
    # Confirm against the upstream file.
    '\x0b': '',
    '\x0c': '',
    '\x1c': '',
    '\x1d': '',
    '\x1e': '',
    '\x85': '',
    '\u2028': '',
    '\u2029': '',
    '<>': '',
    '+/-': '',
    # The next keys look like regular expressions, but multiple_replace()
    # escapes every key, so they are matched as literal text.
    r'\(.+\)': '',
    r'\[.+\]': '',
    r' \d ': '',
    '<': '',
    '>': '',
    '- ': '',
    ' +': ' ',
    ', ,': ',',
    ',,': ',',
    '%': ' percent',
    'per cent': ' percent',
}

patterns = {x.lower(): y for x, y in patterns.items()}
# Accepted label spellings mapped to the binary class id.
LABEL_MAP = {
    'negative': 0,
    'not included': 0,
    '0': 0,
    0: 0,
    'excluded': 0,
    'positive': 1,
    'included': 1,
    '1': 1,
    1: 1,
}


class SLR_DataSet(Dataset):
    """Dataset that tokenises one row of ``data`` per item.

    Args:
        treat_text: Optional callable applied to the raw text before
            tokenisation.

    Keyword Args:
        tokenizer: Object providing ``encode_plus``.
        data: Table of examples; it must support ``len()`` and positional
            ``.iloc`` access (presumably a pandas DataFrame — confirm against
            the caller).
        max_seq_length: Padded/truncated sequence length (default 512).
        input: Name of the text field (default ``'x'``).
        output: Name of the label field (default ``'y'``).
    """

    def __init__(self, treat_text=None, **args):
        self.tokenizer = args.get('tokenizer')
        self.data = args.get('data')
        self.max_seq_length = args.get("max_seq_length", 512)
        self.INPUT_NAME = args.get("input", 'x')
        self.LABEL_NAME = args.get("output", 'y')
        self.treat_text = treat_text

    def _resolve_label(self, example):
        """Return the class id for ``example``, or -1 when it has no usable label."""
        try:
            raw = example[self.LABEL_NAME]
        except (KeyError, IndexError):
            # Unlabelled example: the label field is absent.
            return -1

        # Only strings are case-folded; numeric labels are looked up as they
        # are, which keeps the integer entries of LABEL_MAP reachable.
        if isinstance(raw, str):
            raw = raw.strip().lower()

        try:
            return LABEL_MAP.get(raw, -1)
        except TypeError:
            # Unhashable value: treat as "no label".
            return -1

    # Tokenizing and processing text
    def encode_text(self, example):
        """Tokenise one example.

        Returns:
            Tuple ``(input_ids, attention_mask, token_type_ids, label)`` where
            the first three are the flattened tokenizer outputs and ``label``
            is an int64 tensor of shape ``(1,)`` holding 0, 1 or -1 (unknown).
        """
        comment_text = example[self.INPUT_NAME]
        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        labels = self._resolve_label(example)

        # The text and its fixed second segment are passed as ``text`` and
        # ``text_pair``, the documented encode_plus signature.
        # NOTE(review): a 2-tuple in the ``text`` slot is read as a pair by
        # some tokenizer back ends and as a pre-split token list by others —
        # confirm against the tokenizer in use.
        # NOTE(review): MetaTask.encode_text in this file uses
        # "It is a great text." as its second segment; the two strings
        # differ — confirm which one is intended.
        encoding = self.tokenizer.encode_plus(
            comment_text,
            "It is great text",
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return (
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([labels], dtype=torch.long),
        )

    def __len__(self):
        return len(self.data)

    # Returning data
    def __getitem__(self, index: int):
        """Return the encoded example stored at position ``index``."""
        # ``iloc`` is positional, so the index does not need to be reset
        # (resetting it copied the whole table on every access).
        data_row = self.data.iloc[index]
        return self.encode_text(data_row)
class Learner(nn.Module):
    """Meta-learner that adapts a copy of ``model`` on each task.

    For every task a deep copy of the shared model is trained on the support
    set. When training, the difference between the shared weights and the
    adapted weights is averaged over the tasks and applied to the shared model
    as its gradient through the outer optimizer.

    Keyword Args:
        inner_print: Print the mean inner loss every ``inner_print`` epochs
            (only when ``valid_train`` is true).
        inner_batch_size: Batch size of the support-set loader.
        outer_update_lr: Learning rate of the outer Adam optimizer.
        inner_update_lr: Learning rate of each per-task Adam optimizer.
        inner_update_step: Support-set epochs per task when training.
        inner_update_step_eval: Support-set epochs per task when not training.
        model: Model called as ``model(input_ids, attention_mask,
            token_type_ids=..., labels=...)`` and returning
            ``(loss, _, predictions)``.
        device: Device the per-task copy is trained on.
    """

    def __init__(self, **args):
        """
        :param args: see the class docstring for the recognised keys
        """
        super().__init__()

        self.inner_print = args.get('inner_print')
        self.inner_batch_size = args.get('inner_batch_size')
        self.outer_update_lr = args.get('outer_update_lr')
        self.inner_update_lr = args.get('inner_update_lr')
        self.inner_update_step = args.get('inner_update_step')
        self.inner_update_step_eval = args.get('inner_update_step_eval')
        self.model = args.get('model')
        self.device = args.get('device')

        # Outer optimizer
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training=True, valid_train=True):
        """Run one meta step over ``batch_tasks``.

        batch = [(support TensorDataset, query TensorDataset, task name),
                 (support TensorDataset, query TensorDataset, task name),
                 ...]

        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)

        :param batch_tasks: sequence of ``(support, query, name)`` triples
        :param training: when true, update the shared model after all tasks
        :param valid_train: when true, evaluate each adapted copy on its query
            set and print progress
        :return: mean query accuracy over the tasks when ``valid_train`` is
            true, otherwise ``np.array(0)``
        """
        task_accs = []
        task_f1 = []
        task_recall = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = (
            self.inner_update_step if training else self.inner_update_step_eval
        )

        # Outer loop tasks
        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query = task[1]
            name = task[2]

            # Copying model
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)

            # Inner trainer optimizer
            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)

            # Creating training data loader. A trailing batch holding exactly
            # one sample is dropped (presumably because batch normalisation
            # cannot train on a single sample — TODO confirm).
            drop_last = len(support) % self.inner_batch_size == 1
            support_dataloader = DataLoader(
                support,
                sampler=RandomSampler(support),
                batch_size=self.inner_batch_size,
                drop_last=drop_last,
            )

            fast_model.train()

            # Inner loop training epoch (support set)
            if valid_train:
                print('----Task', task_id, ":", name, '----')

            for i in range(0, num_inner_update_step):
                all_loss = []

                # Inner loop training batch (support set)
                for batch in support_dataloader:
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, token_type_ids, label_id = batch

                    # Feed forward
                    loss, _, _ = fast_model(
                        input_ids,
                        attention_mask,
                        token_type_ids=token_type_ids,
                        labels=label_id,
                    )

                    # Computing gradients
                    loss.backward()

                    # Updating inner training parameters
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()

                    # Appending losses
                    all_loss.append(loss.item())

                    del batch, input_ids, attention_mask, label_id
                    torch.cuda.empty_cache()

                if valid_train and (i + 1) % self.inner_print == 0:
                    print("Inner Loss: ", np.mean(all_loss))

            fast_model.to(torch.device('cpu'))

            # Inner training phase weights
            if training:
                meta_weights = list(self.model.parameters())
                fast_weights = list(fast_model.parameters())

                # Accumulating (shared - adapted) per parameter. Computed
                # without autograd tracking so no graph is kept alive across
                # tasks; the values are unchanged.
                with torch.no_grad():
                    for idx, (meta_params, fast_params) in enumerate(
                            zip(meta_weights, fast_weights)):
                        gradient = meta_params - fast_params
                        if task_id == 0:
                            sum_gradients.append(gradient)
                        else:
                            sum_gradients[idx] += gradient

            if valid_train:
                # Inner test (query set)
                fast_model.to(self.device)
                fast_model.eval()

                with torch.no_grad():
                    # Data loader: the whole query set as one batch
                    query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
                    # Iterators are advanced with the built-in next();
                    # a ``.next()`` method is not part of the iterator protocol.
                    query_batch = next(iter(query_dataloader))
                    query_batch = tuple(t.to(self.device) for t in query_batch)
                    q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch

                    # Feed forward
                    _, _, pre_label_id = fast_model(
                        q_input_ids,
                        q_attention_mask,
                        q_token_type_ids,
                        labels=q_label_id,
                    )

                    # Predictions
                    pre_label_id = pre_label_id.detach().cpu().squeeze()

                    # Labels
                    q_label_id = q_label_id.detach().cpu()

                    # Calculating metrics (each one a plain float)
                    acc = fn.accuracy(pre_label_id, q_label_id).item()
                    recall = fn.recall(pre_label_id, q_label_id).item()
                    f1 = fn.f1_score(pre_label_id, q_label_id).item()

                    # Appending metrics
                    task_accs.append(acc)
                    task_f1.append(f1)
                    task_recall.append(recall)

            fast_model.to(torch.device('cpu'))
            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        # Metrics are only collected when valid_train is true; averaging the
        # empty lists otherwise would print nan and raise a RuntimeWarning.
        if valid_train:
            print("\n")
            print("f1:", np.mean(task_f1))
            print("recall:", np.mean(task_recall))

        # Updating outer training parameters (nothing to apply without tasks)
        if training and sum_gradients:
            # Mean of gradients, indexed back onto the shared model
            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i] / float(num_task)

            # Updating parameters
            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()

            del sum_gradients
            gc.collect()

        torch.cuda.empty_cache()

        if valid_train:
            return np.mean(task_accs)
        return np.array(0)
# Creating Meta Tasks
class MetaTask(Dataset):
    """Collection of few-shot tasks, each a (support, query, name) triple.

    Tasks are drawn from ``examples`` grouped by its ``'domain'`` column; the
    ``'label'`` and ``'text'`` columns provide the class and the input text.
    """

    def __init__(self, examples, num_task, k_support, k_query,
                 tokenizer, training=True, max_seq_length=512,
                 treat_text=None, **args):
        """
        :param examples: table of examples with 'domain', 'label' and 'text'
            columns (presumably a pandas DataFrame — confirm against the caller)
        :param num_task: number of training tasks.
        :param k_support: number of support samples per class per task
        :param k_query: number of query samples per class per task
        :param tokenizer: object providing ``encode_plus``
        :param training: when true, domains are drawn at random; otherwise one
            task is built per domain
        :param max_seq_length: padded/truncated sequence length
        :param treat_text: optional callable applied to the text before
            tokenisation
        :param args: accepted and ignored
        """
        self.examples = examples
        self.num_task = num_task
        self.k_support = k_support
        self.k_query = k_query
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.treat_text = treat_text

        # Randomly generating tasks
        self.create_batch(self.num_task, training)

    # Creating batch
    def create_batch(self, num_task, training):
        """Sample the support and query examples of every task."""
        self.supports = []  # support set
        self.queries = []  # query set
        self.task_names = []  # Name of task
        self.supports_indexs = []  # index of supports
        self.queries_indexs = []  # index of queries
        self.num_task = num_task

        # Available tasks
        domains = self.examples['domain'].unique()

        # If not training, create all tasks
        if not training:
            self.task_names = domains
            num_task = len(self.task_names)
            self.num_task = num_task

        for b in range(num_task):  # For each task,
            total_per_class = self.k_support + self.k_query
            task_size = 2 * self.k_support + 2 * self.k_query

            # Select a task at random
            if training:
                domain = random.choice(domains)
                self.task_names.append(domain)
            else:
                domain = self.task_names[b]

            # Task data
            domainExamples = self.examples[self.examples['domain'] == domain]

            # Minimal label quantity: never ask for more rows than the
            # rarest class of this domain can provide
            min_per_class = min(domainExamples['label'].value_counts())
            if total_per_class > min_per_class:
                total_per_class = min_per_class

            # Select k_support + k_query task examples
            # Sample (n) from each label(class)
            selected_examples = domainExamples.groupby("label").sample(
                total_per_class, replace=False)

            # Split data into support (training) and query (testing) sets
            s, q = train_test_split(selected_examples,
                                    stratify=selected_examples["label"],
                                    test_size=2 * self.k_query / task_size,
                                    shuffle=True)

            # Permutating data
            s = s.sample(frac=1)
            q = q.sample(frac=1)

            # Appending indexes
            if not training:
                self.supports_indexs.append(s.index)
                self.queries_indexs.append(q.index)

            # Creating list of support (training) and query (testing) tasks
            self.supports.append(s.to_dict('records'))
            self.queries.append(q.to_dict('records'))

    # Creating task tensors
    def create_feature_set(self, examples):
        """Encode ``examples`` into a TensorDataset of ids, mask, types and labels."""
        count = len(examples)
        all_input_ids = torch.empty(count, self.max_seq_length, dtype=torch.long)
        all_attention_mask = torch.empty(count, self.max_seq_length, dtype=torch.long)
        all_token_type_ids = torch.empty(count, self.max_seq_length, dtype=torch.long)
        all_label_ids = torch.empty(count, dtype=torch.long)

        for _id, e in enumerate(examples):
            (all_input_ids[_id],
             all_attention_mask[_id],
             all_token_type_ids[_id],
             all_label_ids[_id]) = self.encode_text(e)

        return TensorDataset(
            all_input_ids,
            all_attention_mask,
            all_token_type_ids,
            all_label_ids,
        )

    # Data encoding
    def encode_text(self, example):
        """Tokenise one example.

        Returns ``(input_ids, attention_mask, token_type_ids, label)``; the
        label is an int64 tensor of shape ``(1,)``.

        Raises:
            KeyError: if the example's label is not a key of ``LABEL_MAP``.
        """
        comment_text = example["text"]
        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        # String labels are matched case-insensitively, as SLR_DataSet does;
        # numeric labels are looked up unchanged.
        raw_label = example["label"]
        if isinstance(raw_label, str):
            raw_label = raw_label.strip().lower()
        labels = LABEL_MAP[raw_label]

        # The text and its fixed second segment are passed as ``text`` and
        # ``text_pair``, the documented encode_plus signature.
        # NOTE(review): a 2-tuple in the ``text`` slot is read as a pair by
        # some tokenizer back ends and as a pre-split token list by others —
        # confirm against the tokenizer in use.
        encoding = self.tokenizer.encode_plus(
            comment_text,
            "It is a great text.",
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return (
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([labels], dtype=torch.long),
        )

    # Returns data upon calling
    def __getitem__(self, index):
        """Return ``(support TensorDataset, query TensorDataset, task name)``."""
        support_set = self.create_feature_set(self.supports[index])
        query_set = self.create_feature_set(self.queries[index])
        name = self.task_names[index]
        return support_set, query_set, name

    def __len__(self):
        return self.num_task
class treat_text:
    """Callable that normalises a text before tokenisation.

    Steps, in order: NFKD normalisation, lower-casing, literal replacement of
    every key of ``patterns``, removal of bracketed spans / isolated digits /
    angle brackets / "- ", space collapsing, comma clean-up and rewriting of
    percent signs.

    Args:
        patterns: Mapping of literal substring -> replacement. It is applied
            to lower-cased text, so keys should be lower case.
    """

    # Bracketed spans are matched non-greedily so that each "(...)" or
    # "[...]" is removed on its own; a greedy match would swallow everything
    # between the first opening and the last closing bracket on a line.
    _NOISE = re.compile(r'(\(.+?\))|(\[.+?\])|( \d )|(<)|(>)|(- )')
    _SPACES = re.compile(r'( +)')
    _COMMAS = re.compile(r'(, ,)|(,,)')
    _PERCENT = re.compile(r'(%)|(per cent)')

    def __init__(self, patterns):
        self.patterns = patterns

    def __call__(self, text):
        """Return the cleaned, lower-cased version of ``text``."""
        text = unicodedata.normalize("NFKD", str(text))
        text = multiple_replace(self.patterns, text.lower())
        text = self._NOISE.sub('', text)
        text = self._SPACES.sub(' ', text)
        text = self._COMMAS.sub(',', text)
        text = self._PERCENT.sub(' percent', text)
        return text


# Regex multiple replace function
def multiple_replace(dict, text):
    """Replace every key of ``dict`` found in ``text`` with its value.

    Keys are matched as literal text in a single pass. When one key is a
    prefix of another, the longer key wins regardless of mapping order.

    Args:
        dict: Mapping of literal substring -> replacement. An empty or
            ``None`` mapping leaves ``text`` unchanged.
        text: String to process.

    Returns:
        The text with all replacements applied.
    """
    # An empty alternation would match the empty string at every position
    # and then fail the lookup below.
    if not dict:
        return text

    # Building regex from dict keys, longest first (stable for equal lengths)
    keys = sorted(dict, key=len, reverse=True)
    regex = re.compile("(%s)" % "|".join(map(re.escape, keys)))

    # Substitution
    return regex.sub(lambda mo: dict[mo.group(0)], text)