# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UBERT: dataset encoding, biaffine span model, Lightning wrappers and a
high-level train/predict pipeline."""

from logging import basicConfig, setLogRecordFactory
import torch
from torch import nn
import json
from tqdm import tqdm
import os
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BertTokenizer,
    file_utils
)
import pytorch_lightning as pl

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import trainer, loggers
from torch.utils.data import Dataset, DataLoader
from transformers.optimization import get_linear_schedule_with_warmup
from transformers import BertForPreTraining, BertForMaskedLM, BertModel
from transformers import BertConfig, BertForTokenClassification, BertPreTrainedModel
import transformers
import unicodedata
import re
import argparse

transformers.logging.set_verbosity_error()
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'


def search(pattern, sequence):
    """Return every ``[start, end]`` (both inclusive) index pair at which
    *pattern* occurs as a contiguous run inside *sequence*.

    An empty *pattern* matches at every index of *sequence*.
    """
    n = len(pattern)
    res = []
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            res.append([i, i + n - 1])
    return res


class UbertDataset(Dataset):
    """Torch dataset that turns raw UBERT items into fixed-size tensors.

    Every item is expanded into exactly ``args.num_labels`` (query, text)
    pairs, one per choice, padded or truncated as needed.
    """

    def __init__(self, data, tokenizer, args, used_mask=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = args.max_length
        self.num_labels = args.num_labels
        self.used_mask = used_mask
        self.data = data
        self.args = args

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index], self.used_mask)

    def encode(self, item, used_mask=False):
        """Encode one raw item into model inputs and span targets.

        Args:
            item: dict read with keys ``'task_type'``, ``'subtask_type'``,
                ``'text'`` and ``'choices'``. Each choice is read for
                ``'entity_type'`` plus either ``'label'`` (classification) or
                ``'entity_list'`` (span tasks, each entity holding
                ``'entity_idx'`` character ranges).
            used_mask: accepted for API compatibility; not read here.

        Returns:
            dict of tensors ``input_ids``/``token_type_ids``/``attention_mask``
            with shape ``(num_labels, max_length)`` and ``span_labels`` /
            ``span_labels_mask`` with shape ``(num_labels, max_length, max_length)``.
        """
        # "...1" lists hold choices with at least one positive span,
        # "...0" lists hold choices with none (used as padding negatives).
        input_ids1 = []
        attention_mask1 = []
        token_type_ids1 = []
        span_labels1 = []
        span_labels_masks1 = []

        input_ids0 = []
        attention_mask0 = []
        token_type_ids0 = []
        span_labels0 = []
        span_labels_masks0 = []

        subtask_type = item['subtask_type']
        for choice in item['choices']:
            # Pre-set so the diagnostics below never raise NameError or show
            # values left over from a previous choice.
            texta, textb = None, None
            try:
                texta = item['task_type'] + '[SEP]' + \
                    subtask_type + '[SEP]' + choice['entity_type']
                textb = item['text']
                encode_dict = self.tokenizer.encode_plus(texta, textb,
                                                         max_length=self.max_length,
                                                         padding='max_length',
                                                         truncation='longest_first')

                encode_sent = encode_dict['input_ids']
                encode_token_type_ids = encode_dict['token_type_ids']
                encode_attention_mask = encode_dict['attention_mask']
                span_label = np.zeros((self.max_length, self.max_length))
                # Additive logit mask: 0 where a span may be predicted,
                # -10000 everywhere else.
                span_label_mask = np.zeros(
                    (self.max_length, self.max_length)) - 10000

                if item['task_type'] == '分类任务':  # "classification task"
                    # Classification is scored on the single [0, 0] cell.
                    span_label_mask[0, 0] = 0
                    span_label[0, 0] = choice['label']

                else:
                    # Only spans that lie inside the text segment (after the
                    # "[CLS] query [SEP]" prefix) are allowed.
                    question_len = len(self.tokenizer.encode(texta))
                    span_label_mask[question_len:, question_len:] = np.zeros(
                        (self.max_length - question_len, self.max_length - question_len))
                    for entity in choice['entity_list']:
                        entity_idx_list = entity['entity_idx']
                        if entity_idx_list == []:
                            continue
                        for entity_idx in entity_idx_list:
                            if entity_idx == []:
                                continue
                            # Convert character offsets to token offsets by
                            # tokenizing the text prefix up to each boundary.
                            start_idx_text = item['text'][:entity_idx[0]]
                            start_idx_text_encode = self.tokenizer.encode(
                                start_idx_text, add_special_tokens=False)
                            start_idx = question_len + len(start_idx_text_encode)

                            end_idx_text = item['text'][:entity_idx[1] + 1]
                            end_idx_text_encode = self.tokenizer.encode(
                                end_idx_text, add_special_tokens=False)
                            end_idx = question_len + len(end_idx_text_encode) - 1
                            if start_idx < self.max_length and end_idx < self.max_length:
                                span_label[start_idx, end_idx] = 1

                if np.sum(span_label) < 1:
                    input_ids0.append(encode_sent)
                    attention_mask0.append(encode_attention_mask)
                    token_type_ids0.append(encode_token_type_ids)
                    span_labels0.append(span_label)
                    span_labels_masks0.append(span_label_mask)
                else:
                    input_ids1.append(encode_sent)
                    attention_mask1.append(encode_attention_mask)
                    token_type_ids1.append(encode_token_type_ids)
                    span_labels1.append(span_label)
                    span_labels_masks1.append(span_label_mask)
            except Exception:
                # Best-effort: a malformed choice is reported and skipped.
                print(item)
                print(texta)
                print(textb)

        # Top up with randomly chosen negative choices until num_labels rows.
        randomize = np.arange(len(input_ids0))
        np.random.shuffle(randomize)
        cur = 0
        count = len(input_ids1)
        while count < self.args.num_labels:
            if cur < len(randomize):
                input_ids1.append(input_ids0[randomize[cur]])
                attention_mask1.append(attention_mask0[randomize[cur]])
                token_type_ids1.append(token_type_ids0[randomize[cur]])
                span_labels1.append(span_labels0[randomize[cur]])
                span_labels_masks1.append(span_labels_masks0[randomize[cur]])
            cur += 1
            count += 1

        # Still short (not enough choices): pad with all-zero, fully masked rows.
        while len(input_ids1) < self.args.num_labels:
            input_ids1.append([0] * self.max_length)
            attention_mask1.append([0] * self.max_length)
            token_type_ids1.append([0] * self.max_length)
            span_labels1.append(np.zeros((self.max_length, self.max_length)))
            span_labels_masks1.append(
                np.zeros((self.max_length, self.max_length)) - 10000)

        input_ids = input_ids1[:self.args.num_labels]
        attention_mask = attention_mask1[:self.args.num_labels]
        token_type_ids = token_type_ids1[:self.args.num_labels]
        span_labels = span_labels1[:self.args.num_labels]
        span_labels_masks = span_labels_masks1[:self.args.num_labels]

        span_labels = np.array(span_labels)
        span_labels_masks = np.array(span_labels_masks)
        if np.sum(span_labels) < 1:
            # No positive anywhere: plant a synthetic positive in the last
            # cell (with a large mask bonus) so the ranking losses always
            # see at least one positive target.
            span_labels[-1, -1, -1] = 1
            span_labels_masks[-1, -1, -1] = 10000

        sample = {
            "input_ids": torch.tensor(input_ids).long(),
            "token_type_ids": torch.tensor(token_type_ids).long(),
            "attention_mask": torch.tensor(attention_mask).float(),
            "span_labels": torch.tensor(span_labels).float(),
            "span_labels_mask": torch.tensor(span_labels_masks).float()
        }

        return sample


class UbertDataModel(pl.LightningDataModule):
    """Lightning data module wrapping train/validation ``UbertDataset``s."""

    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('TASK NAME DataModel')
        # NOTE(review): --num_workers is parsed but not passed to DataLoader.
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--batchsize', default=8, type=int)
        parser.add_argument('--max_length', default=128, type=int)
        return parent_args

    def __init__(self, train_data, val_data, tokenizer, args):
        super().__init__()
        self.batchsize = args.batchsize

        self.train_data = UbertDataset(train_data, tokenizer, args, True)
        self.valid_data = UbertDataset(val_data, tokenizer, args, False)

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True, batch_size=self.batchsize, pin_memory=False)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False, batch_size=self.batchsize, pin_memory=False)


class biaffine(nn.Module):
    """Biaffine scorer: ``out[b, x, y, o] = x_b,x^T U[:, o, :] y_b,y``.

    When ``bias_x``/``bias_y`` is set, a constant 1 feature is appended to the
    corresponding input so ``U`` also learns linear and bias terms.
    """

    def __init__(self, in_size, out_size, bias_x=True, bias_y=True):
        super().__init__()
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.out_size = out_size
        self.U = torch.nn.Parameter(torch.zeros(
            in_size + int(bias_x), out_size, in_size + int(bias_y)))
        torch.nn.init.normal_(self.U, mean=0, std=0.1)

    def forward(self, x, y):
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1)
        # (b, X, i) x (i, o, j) x (b, Y, j) -> (b, X, Y, o)
        bilinar_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)
        return bilinar_mapping


class MultilabelCrossEntropy(nn.Module):
    """Multi-label ranking loss over the last dimension.

    Per row: ``log(1 + sum_neg exp(s)) + log(1 + sum_pos exp(-s))``,
    averaged over all rows. ``y_true`` is a 0/1 tensor shaped like ``y_pred``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        y_true = y_true.float()
        # Negate the logits of positives so both terms are "push down" terms.
        y_pred = torch.mul((1.0 - torch.mul(y_true, 2.0)), y_pred)
        # Drop positives from the negative term and vice versa.
        y_pred_neg = y_pred - torch.mul(y_true, 1e12)
        y_pred_pos = y_pred - torch.mul(1.0 - y_true, 1e12)
        # Appending a 0 logit adds the "1 +" inside each log.
        zeros = torch.zeros_like(y_pred[..., :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        loss = torch.mean(neg_loss + pos_loss)
        return loss


class UbertModel(BertPreTrainedModel):
    """BERT encoder + biaffine head producing one span-logit grid per choice.

    Reads ``config.hidden_size`` and ``config.biaffine_size``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.query_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.config.hidden_size,
                            out_features=self.config.biaffine_size),
            torch.nn.GELU())
        self.key_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.config.hidden_size,
                            out_features=self.config.biaffine_size),
            torch.nn.GELU())
        self.biaffine_query_key_cls = biaffine(self.config.biaffine_size, 1)
        self.loss_softmax = MultilabelCrossEntropy()
        self.loss_sigmoid = torch.nn.BCEWithLogitsLoss(reduction='mean')

    def forward(self, input_ids, attention_mask, token_type_ids, span_labels=None, span_labels_mask=None):
        """Score every (start, end) span for every choice.

        Args:
            input_ids, attention_mask, token_type_ids: ``(batch, num_label, seq_len)``.
            span_labels: optional 0/1 targets ``(batch, num_label, seq_len, seq_len)``.
            span_labels_mask: additive logit mask with the same shape as the logits;
                required (it is always added).

        Returns:
            ``(loss, span_logits)``. ``loss`` is ``0`` when ``span_labels`` is None.
        """
        batch_size, num_label, seq_len = input_ids.shape

        # Fold the choice axis into the batch for the encoder.
        input_ids = input_ids.view(-1, seq_len)
        attention_mask = attention_mask.view(-1, seq_len)
        token_type_ids = token_type_ids.view(-1, seq_len)

        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            output_hidden_states=True)  # (bsz, seq, dim)

        hidden_states = outputs[0]

        query = self.query_layer(hidden_states)
        key = self.key_layer(hidden_states)

        span_logits = self.biaffine_query_key_cls(
            query, key).reshape(-1, num_label, seq_len, seq_len)

        span_logits = span_logits + span_labels_mask

        if span_labels is None:
            return 0, span_logits

        # Ranking over all spans of each choice ...
        soft_loss1 = self.loss_softmax(
            span_logits.reshape(-1, num_label, seq_len * seq_len),
            span_labels.reshape(-1, num_label, seq_len * seq_len))
        # ... ranking over choices for each span position ...
        soft_loss2 = self.loss_softmax(span_logits.permute(
            0, 2, 3, 1), span_labels.permute(0, 2, 3, 1))
        # ... plus an element-wise binary loss.
        sig_loss = self.loss_sigmoid(span_logits, span_labels)
        all_loss = 10 * (100 * sig_loss + soft_loss1 + soft_loss2)
        return all_loss, span_logits


class UbertLitModel(pl.LightningModule):
    """Lightning wrapper: optimisation, LR schedule and span metrics."""

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        parser.add_argument('--num_labels', default=10, type=int)

        return parent_args

    def __init__(self, args, num_data=1):
        super().__init__()
        self.args = args
        self.num_data = num_data
        self.model = UbertModel.from_pretrained(
            self.args.pretrained_model_path)
        self.count = 0

    def setup(self, stage) -> None:
        if stage == 'fit':
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            # NOTE(review): num_data is set to the number of training
            # *examples* (see UbertPiplines.fit), not batches, so this step
            # count is not divided by the batch size — confirm intended.
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def training_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        span_acc, recall, precise = self.comput_metrix_span(
            span_logits, batch['span_labels'])
        self.log('train_loss', loss)
        self.log('train_span_acc', span_acc)
        self.log('train_span_recall', recall)
        self.log('train_span_precise', precise)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        span_acc, recall, precise = self.comput_metrix_span(
            span_logits, batch['span_labels'])

        self.log('val_loss', loss)
        self.log('val_span_acc', span_acc)
        self.log('val_span_recall', recall)
        self.log('val_span_precise', precise)

    def predict_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        # comput_metrix_span returns (f1, recall, precision); the original
        # called .item() on the whole tuple, which raised AttributeError.
        span_acc, _, _ = self.comput_metrix_span(
            span_logits, batch['span_labels'])
        return span_acc.item()

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]

    def comput_metrix_span(self, logits, labels):
        """Return ``(f1, recall, precision)`` of thresholding *logits* at 0
        against 0/1 *labels* (the value logged as ``*_span_acc`` is the F1)."""
        ones = torch.ones_like(logits)
        zero = torch.zeros_like(logits)
        logits = torch.where(logits < 0, zero, ones)
        y_pred = logits.view(size=(-1,))
        y_true = labels.view(size=(-1,))
        corr = torch.eq(y_pred, y_true).float()
        corr = torch.multiply(y_true, corr)  # true positives only
        recall = torch.sum(corr.float()) / (torch.sum(y_true.float()) + 1e-5)
        precise = torch.sum(corr.float()) / (torch.sum(y_pred.float()) + 1e-5)
        f1 = 2 * recall * precise / (recall + precise + 1e-5)
        return f1, recall, precise


def _str2bool(value):
    """argparse ``type=`` for booleans.

    ``type=bool`` would turn any non-empty string, including ``"False"``,
    into True.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


class TaskModelCheckpoint:
    """Builds the ``ModelCheckpoint`` callback from parsed CLI arguments."""

    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--checkpoint_path',
                            default='./checkpoint/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)

        # Counts are integers (ModelCheckpoint documents them as int).
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_epochs', default=1, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)

        parser.add_argument('--save_weights_only', default=True, type=_str2bool)
        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         save_last=True,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.checkpoint_path,
                                         filename=args.filename)


class OffsetMapping:
    """Maps BERT wordpiece tokens back to character positions in the text."""

    def __init__(self):
        self._do_lower_case = True

    @staticmethod
    def stem(token):
        """Strip the ``##`` continuation prefix of a wordpiece."""
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_control(ch):
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """True for bracketed tokens such as ``[CLS]`` / ``[UNK]``."""
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    def rematch(self, text, tokens):
        """Return, for each token, the list of original character indices it covers.

        Special (bracketed) tokens are mapped to the single current offset.
        """
        if self._do_lower_case:
            text = text.lower()

        # Normalise the text the same way the tokenizer does, remembering
        # which original character each normalised character came from.
        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = unicodedata.normalize('NFD', ch)
                ch = ''.join(
                    [c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([offset])
                offset += 1
            else:
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end

        return token_mapping


class extractModel:
    """Turns model span logits back into labels / entity strings."""

    def get_actual_id(self, text, query_text, tokenizer, args):
        """Locate *text*'s tokens inside the encoding of *query_text*.

        Returns ``(text_start_id, text_end_id, text_mapping, query_ids)`` where
        ``text_start_id`` is the token index at which *text* starts in the
        (un-truncated) query encoding and ``text_mapping`` comes from
        ``OffsetMapping.rematch``.
        """
        text_encode = tokenizer.encode(text)
        one_input_encode = tokenizer.encode(query_text)
        text_start_id = search(text_encode[1:-1], one_input_encode)[0][0]
        text_end_id = text_start_id + len(text_encode) - 1
        if text_end_id > args.max_length:
            text_end_id = args.max_length

        text_token = tokenizer.tokenize(text)
        text_mapping = OffsetMapping().rematch(text, text_token)

        return text_start_id, text_end_id, text_mapping, one_input_encode

    def extract_index(self, span_logits, sample_length, split_value=0.5):
        """Return ``(i, j, score)`` for every span ``i <= j < sample_length``
        whose score exceeds *split_value*.

        *sample_length* is clamped to the logits grid: the un-truncated query
        can be longer than ``max_length``, which previously raised IndexError.
        """
        sample_length = min(sample_length, span_logits.shape[0], span_logits.shape[1])
        result = []
        for i in range(sample_length):
            for j in range(i, sample_length):
                if span_logits[i, j] > split_value:
                    result.append((i, j, span_logits[i, j]))
        return result

    def extract_entity(self, text, entity_idx, text_start_id, text_mapping):
        """Map a token-level ``(start, end)`` span to its substring of *text*.

        Returns ``''`` when either end falls outside the text segment.
        """
        start_split = text_mapping[entity_idx[0] - text_start_id] if entity_idx[0] - \
            text_start_id < len(text_mapping) and entity_idx[0] - text_start_id >= 0 else []
        end_split = text_mapping[entity_idx[1] - text_start_id] if entity_idx[1] - \
            text_start_id < len(text_mapping) and entity_idx[1] - text_start_id >= 0 else []
        entity = ''
        if start_split != [] and end_split != []:
            entity = text[start_split[0]:end_split[-1] + 1]
        return entity

    def extract(self, batch_data, model, tokenizer, args):
        """Run *model* on *batch_data* and write predictions into the items.

        The item dicts are mutated in place and the same list is returned:
        classification items get ``'label'``/``'score'`` on their best choice;
        span items get ``'entity_list'`` on every choice.

        NOTE(review): items in one batch presumably must have the same number
        of choices, since they are stacked into a single tensor — confirm.
        """
        input_ids = []
        attention_mask = []
        token_type_ids = []
        span_labels_masks = []

        for item in batch_data:
            input_ids0 = []
            attention_mask0 = []
            token_type_ids0 = []
            span_labels_masks0 = []
            for choice in item['choices']:
                texta = item['task_type'] + '[SEP]' + \
                    item['subtask_type'] + '[SEP]' + choice['entity_type']
                textb = item['text']
                encode_dict = tokenizer.encode_plus(texta, textb,
                                                    max_length=args.max_length,
                                                    padding='max_length',
                                                    truncation='longest_first')

                encode_sent = encode_dict['input_ids']
                encode_token_type_ids = encode_dict['token_type_ids']
                encode_attention_mask = encode_dict['attention_mask']
                span_label_mask = np.zeros(
                    (args.max_length, args.max_length)) - 10000

                if item['task_type'] == '分类任务':  # "classification task"
                    span_label_mask[0, 0] = 0
                else:
                    question_len = len(tokenizer.encode(texta))
                    span_label_mask[question_len:, question_len:] = np.zeros(
                        (args.max_length - question_len, args.max_length - question_len))
                input_ids0.append(encode_sent)
                attention_mask0.append(encode_attention_mask)
                token_type_ids0.append(encode_token_type_ids)
                span_labels_masks0.append(span_label_mask)

            input_ids.append(input_ids0)
            attention_mask.append(attention_mask0)
            token_type_ids.append(token_type_ids0)
            span_labels_masks.append(span_labels_masks0)

        input_ids = torch.tensor(input_ids).to(model.device)
        attention_mask = torch.tensor(attention_mask).to(model.device)
        token_type_ids = torch.tensor(token_type_ids).to(model.device)
        # Stack the ndarrays first (torch.tensor on a list of ndarrays is very
        # slow); dtype stays float64 so returned scores remain np.float64.
        span_labels_mask = torch.tensor(np.array(span_labels_masks)).to(model.device)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            _, span_logits = model.model(input_ids=input_ids,
                                         attention_mask=attention_mask,
                                         token_type_ids=token_type_ids,
                                         span_labels=None,
                                         span_labels_mask=span_labels_mask)

        span_logits = torch.sigmoid(span_logits)
        span_logits = span_logits.cpu().detach().numpy()

        for i, item in enumerate(batch_data):
            if item['task_type'] == '分类任务':  # "classification task"
                cls_idx = 0
                max_c = np.argmax(span_logits[i, :, cls_idx, cls_idx])
                batch_data[i]['choices'][max_c]['label'] = 1
                batch_data[i]['choices'][max_c]['score'] = span_logits[i, max_c, cls_idx, cls_idx]
            else:
                if item['subtask_type'] == '抽取式阅读理解':  # "extractive reading comprehension"
                    for c in range(len(item['choices'])):
                        # Must match the query fed to the model above, otherwise
                        # text_start_id is offset and the wrong substring is
                        # returned. (The original used a stale `choice` and
                        # dropped the task_type prefix.)
                        texta = item['task_type'] + '[SEP]' + item['subtask_type'] + \
                            '[SEP]' + item['choices'][c]['entity_type']
                        textb = item['text']
                        text_start_id, text_end_id, offset_mapping, query_ids = self.get_actual_id(
                            item['text'], texta + '[SEP]' + textb, tokenizer, args)
                        logits = span_logits[i, c, :, :]
                        # Reading comprehension keeps only the single best span.
                        max_index = np.unravel_index(
                            np.argmax(logits, axis=None), logits.shape)
                        entity_list = []
                        if logits[max_index] > args.threshold:
                            entity = self.extract_entity(
                                item['text'], (max_index[0], max_index[1]), text_start_id, offset_mapping)
                            entity = {
                                'entity_name': entity,
                                'score': logits[max_index]
                            }
                            if entity not in entity_list:
                                entity_list.append(entity)
                        batch_data[i]['choices'][c]['entity_list'] = entity_list
                else:
                    for c in range(len(item['choices'])):
                        texta = item['task_type'] + '[SEP]' + item['subtask_type'] + \
                            '[SEP]' + item['choices'][c]['entity_type']
                        textb = item['text']
                        text_start_id, text_end_id, offset_mapping, query_ids = self.get_actual_id(
                            item['text'], texta + '[SEP]' + textb, tokenizer, args)
                        logits = span_logits[i, c, :, :]
                        sample_length = len(query_ids)
                        entity_idx_type_list = self.extract_index(
                            logits, sample_length, split_value=args.threshold)
                        entity_list = []
                        for entity_idx in entity_idx_type_list:
                            entity = self.extract_entity(
                                item['text'], (entity_idx[0], entity_idx[1]), text_start_id, offset_mapping)
                            entity = {
                                'entity_name': entity,
                                'score': entity_idx[2]
                            }
                            if entity not in entity_list:
                                entity_list.append(entity)
                        batch_data[i]['choices'][c]['entity_list'] = entity_list
        return batch_data


class UbertPiplines:
    """High-level entry point: build model/trainer/tokenizer, then fit or predict."""

    @staticmethod
    def piplines_args(parent_args):
        total_parser = parent_args.add_argument_group("piplines args")
        total_parser.add_argument(
            '--pretrained_model_path', default='IDEA-CCNL/Erlangshen-Ubert-110M-Chinese', type=str)
        total_parser.add_argument('--output_save_path',
                                  default='./predict.json', type=str)

        total_parser.add_argument('--load_checkpoints_path',
                                  default='', type=str)

        total_parser.add_argument('--max_extract_entity_number',
                                  default=1, type=float)

        total_parser.add_argument('--train', action='store_true')

        total_parser.add_argument('--threshold',
                                  default=0.5, type=float)

        total_parser = UbertDataModel.add_data_specific_args(total_parser)
        total_parser = TaskModelCheckpoint.add_argparse_args(total_parser)
        total_parser = UbertLitModel.add_model_specific_args(total_parser)
        total_parser = pl.Trainer.add_argparse_args(parent_args)

        return parent_args

    def __init__(self, args):
        if args.load_checkpoints_path != '':
            self.model = UbertLitModel.load_from_checkpoint(
                args.load_checkpoints_path, args=args)
        else:
            self.model = UbertLitModel(args)

        self.args = args
        self.checkpoint_callback = TaskModelCheckpoint(args).callbacks
        self.logger = loggers.TensorBoardLogger(save_dir=args.default_root_dir)
        self.trainer = pl.Trainer.from_argparse_args(args,
                                                     logger=self.logger,
                                                     callbacks=[self.checkpoint_callback])

        self.tokenizer = BertTokenizer.from_pretrained(
            args.pretrained_model_path,
            additional_special_tokens=['[unused' + str(i + 1) + ']' for i in range(99)])

        self.em = extractModel()

    def fit(self, train_data, dev_data):
        data_model = UbertDataModel(
            train_data, dev_data, self.tokenizer, self.args)
        self.model.num_data = len(train_data)
        self.trainer.fit(self.model, data_model)

    def predict(self, test_data, cuda=True):
        """Predict on *test_data* in batches of ``args.batchsize``.

        The item dicts in *test_data* are updated in place by
        ``extractModel.extract``; the list of all processed items is returned.
        """
        result = []
        start = 0
        if cuda:
            self.model = self.model.cuda()
        self.model.eval()
        while start < len(test_data):
            batch_data = test_data[start:start + self.args.batchsize]
            start += self.args.batchsize

            batch_result = self.em.extract(
                batch_data, self.model, self.tokenizer, self.args)
            result.extend(batch_result)
        return result