import sys

sys.path.append("..")

from dataset.vocab import Vocab
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer


class TokenAligner:
    """Builds character-level source inputs (with "@@" continuation markers)
    and subword target ids on top of a Hugging Face tokenizer."""

    def __init__(self, tokenizer: AutoTokenizer, vocab: Vocab):
        self.tokenizer = tokenizer
        self.vocab = vocab
""" | |
params: | |
text ---- str | |
""" | |
def _char_tokenize(self, text): | |
characters = list(text) | |
tokens = [ token + "@@" if i < len(characters) - 1 and characters[i + 1] != " " else token for i, token in enumerate(characters)] | |
tokens = [token for token in tokens if token not in [" @@", " "]] | |
encoded = self.tokenizer.encode_plus(tokens, return_tensors = "pt") | |
token_ids = encoded['input_ids'].squeeze(0) | |
attn_mask = encoded['attention_mask'].squeeze(0) | |
return tokens, token_ids, attn_mask | |
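
    # Example (sketch): _char_tokenize("xin chao") produces the token list
    # ['x@@', 'i@@', 'n', 'c@@', 'h@@', 'a@@', 'o'] -- every character whose
    # successor is a non-space character carries the "@@" continuation marker.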

    def char_tokenize(self, batch_texts):
        """Apply _char_tokenize to a batch of strings and collect the results."""
        doc = {'tokens': [], 'token_ids': [], 'attention_mask': []}
        for text in batch_texts:
            tokens, token_ids, attn_mask = self._char_tokenize(text)
            doc['tokens'].append(tokens)
            doc['token_ids'].append(token_ids)
            doc['attention_mask'].append(attn_mask)
        return doc
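
    # Sketch of the batched output: char_tokenize(["ab c"]) returns
    # {'tokens': [['a@@', 'b', 'c']], 'token_ids': [<1-D id tensor>],
    #  'attention_mask': [<1-D mask tensor>]}; the per-sample tensors are
    # still ragged here and are padded in the method below.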

    def tokenize_for_transformer_with_tokenization(self, batch_noised_text, batch_label_texts=None):
        docs = self.char_tokenize(batch_noised_text)
        batch_srcs = docs['token_ids']
        batch_attention_masks = docs['attention_mask']
        # Pad every per-sample tensor to the longest sequence in the batch.
        batch_attention_masks = pad_sequence(batch_attention_masks,
                                             batch_first=True, padding_value=0)
        batch_srcs = pad_sequence(batch_srcs,
                                  batch_first=True, padding_value=self.tokenizer.pad_token_id)
        if batch_label_texts is not None:
            # Subword lengths of the clean label texts, measured before padding.
            batch_lengths = [len(self.tokenizer.tokenize(text)) for text in batch_label_texts]
            batch_tgts = self.tokenizer.batch_encode_plus(batch_label_texts, max_length=512,
                                                          truncation=True, padding=True,
                                                          return_tensors="pt")['input_ids']
            return batch_srcs, batch_tgts, batch_lengths, batch_attention_masks
        return batch_srcs, batch_attention_masks
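

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The checkpoint
    # name and the Vocab constructor arguments are assumptions made for
    # illustration; it also assumes the tokenizer's vocabulary contains the
    # "@@"-marked character tokens that _char_tokenize emits.
    tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-syllable")  # hypothetical checkpoint
    vocab = Vocab("../dataset/vocab.txt")  # hypothetical path and signature
    aligner = TokenAligner(tokenizer, vocab)

    noised = ["toi dii hoc", "xin chao"]
    labels = ["tôi đi học", "xin chào"]
    srcs, tgts, lengths, attention_masks = aligner.tokenize_for_transformer_with_tokenization(
        noised, labels)
    print(srcs.shape, tgts.shape, lengths, attention_masks.shape)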