# NMT-LaVi/utils/data.py

import os
import re

import dill as pickle
import nltk
import pandas as pd
import spacy  # required by the Tokenizer class below
from laonlp import tokenize
from nltk.corpus import wordnet
# Field / TabularDataset / Iterator below come from the legacy torchtext API
# (moved to torchtext.legacy in newer releases and later removed).
from torchtext import data

def multiple_replace(dict, text):
    # Build one regular expression that matches any of the dictionary keys
    # (note: the parameter name shadows the built-in dict)
    regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
    # For each match, look up the corresponding replacement in the dictionary
    return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
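
# Usage sketch for multiple_replace (values are illustrative only):
#   multiple_replace({"&quot;": '"', "&amp;": "&"}, "Tom &amp; Jerry")
#   -> 'Tom & Jerry'
# Keys are regex-escaped, so punctuation inside them is matched literally.
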
# get_synonym: replace a word with the vocabulary index of any WordNet synonym
# that already exists in the source vocabulary (index 0 is <unk>).
def get_synonym(word, SRC):
    syns = wordnet.synsets(word)
    for s in syns:
        for l in s.lemmas():
            if SRC.vocab.stoi[l.name()] != 0:
                return SRC.vocab.stoi[l.name()]
    return 0
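
# get_synonym relies on the WordNet corpus, which must be downloaded once,
# e.g. nltk.download('wordnet'). Usage sketch (names are illustrative):
#   idx = get_synonym('carriage', SRC)   # 0 if no in-vocabulary synonym exists
#   word = SRC.vocab.itos[idx]
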
class Tokenizer:
    def __init__(self, lang=None):
        if lang is not None:
            # Use a spaCy pipeline's tokenizer when a model name is given
            self.nlp = spacy.load(lang)
            self.tokenizer_fn = self.nlp.tokenizer
        else:
            # Fall back to simple whitespace tokenization
            self.tokenizer_fn = lambda l: l.strip().split()

    # def tokenize(self, sentence):
    #     sentence = re.sub(
    #         r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence))
    #     sentence = re.sub(r"[ ]+", " ", sentence)
    #     sentence = re.sub(r"\!+", "!", sentence)
    #     sentence = re.sub(r"\,+", ",", sentence)
    #     sentence = re.sub(r"\?+", "?", sentence)
    #     sentence = sentence.lower()
    #     return [tok.text for tok in self.tokenizer_fn(sentence) if tok.text != " "]
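
# Usage sketch for Tokenizer (the model name is illustrative and must be
# installed separately, e.g. `python -m spacy download en_core_web_sm`):
#   tok = Tokenizer()                   # whitespace fallback
#   tok.tokenizer_fn('hello world')     # -> ['hello', 'world']
#   tok = Tokenizer('en_core_web_sm')   # spaCy-based tokenization
#   [t.text for t in tok.tokenizer_fn('hello world')]
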
def read_data(src_file, trg_file):
    # Read parallel source/target corpora, one sentence per line
    with open(src_file, encoding='utf-8') as f:
        src_data = f.read().strip().split('\n')
    with open(trg_file, encoding='utf-8') as f:
        trg_data = f.read().strip().split('\n')
    return src_data, trg_data

def read_file(file_dir):
    with open(file_dir, 'r', encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return lines

def write_file(file_dir, content):
    with open(file_dir, 'w', encoding='utf-8') as f:
        f.write(content)

def create_fields(src_lang, trg_lang):
    # print("loading spacy tokenizers...")
    # t_src = tokenize(src_lang)
    # t_trg = tokenize(trg_lang)
    # t_src_tokenizer = t_trg_tokenizer = lambda x: x.strip().split()

    # Target side: plain whitespace tokenization
    target_tokenizer = lambda x: x.strip().split()
    TRG = data.Field(lower=True, tokenize=target_tokenizer, init_token='<sos>', eos_token='<eos>')
    # Source side: Lao word segmentation from laonlp
    SRC = data.Field(lower=True, tokenize=tokenize.word_tokenize)
    return SRC, TRG
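
# Usage sketch (argument values are illustrative; src_lang and trg_lang are
# currently unused, so the fields always pair laonlp word segmentation for the
# source with whitespace splitting for the target):
#   SRC, TRG = create_fields('lo', 'vi')
# Vocabularies are built later, inside create_dataset, via build_vocab.
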
def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):
    print("creating dataset and iterator... ")
    raw_data = {'src': [line for line in src_data], 'trg': [line for line in trg_data]}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])
    # Keep only sentence pairs shorter than max_strlen tokens (counted by spaces)
    mask = (df['src'].str.count(' ') < max_strlen) & (df['trg'].str.count(' ') < max_strlen)
    df = df.loc[mask]
    # TabularDataset reads from disk, so round-trip through a temporary CSV;
    # skip_header keeps the "src,trg" header row out of the examples
    df.to_csv("translate_transformer_temp.csv", index=False)
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv',
                                fields=data_fields, skip_header=True)
    train_iter = MyIterator(train, batch_size=batchsize, device=device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=istrain, shuffle=True)
    os.remove('translate_transformer_temp.csv')
    if istrain:
        SRC.build_vocab(train)
        TRG.build_vocab(train)
    return train_iter
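
# End-to-end usage sketch (file names and sizes are illustrative, not taken
# from this repository):
#   src_data, trg_data = read_data('train.lo', 'train.vi')
#   SRC, TRG = create_fields('lo', 'vi')
#   train_iter = create_dataset(src_data, trg_data, max_strlen=160,
#                               batchsize=1500, device='cpu', SRC=SRC, TRG=TRG)
#   for batch in train_iter:
#       batch.src, batch.trg  # LongTensors of shape (seq_len, batch)
#       break
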
class MyIterator(data.Iterator):
    # Batches together examples of similar length to minimise padding
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                # Pull a large chunk of examples, sort it by length, slice it
                # into token-budgeted batches, then shuffle those batches
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

global max_src_in_batch, max_tgt_in_batch

def batch_size_fn(new, count, sofar):
    "Keep augmenting the batch and calculate the total number of tokens + padding."
    # With this function, MyIterator's batch_size acts as a token budget rather
    # than a sentence count: the cost of a batch is its padded size.
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)  # +2 for <sos>/<eos>
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)
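
# Worked example of the batching arithmetic (numbers are illustrative): if the
# current batch holds count=20 pairs, the longest source so far is 25 tokens
# and the longest target is 30 (+2 for <sos>/<eos> = 32), the cost is
# max(20 * 25, 20 * 32) = 640 padded tokens; a new batch is started once this
# exceeds the batch_size passed to MyIterator.
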
def generate_language_token(lang: str):
    # Wrap a language code in angle brackets, e.g. 'vi' -> '<vi>'
    return '<{}>'.format(lang.strip())
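
# Usage sketch: generate_language_token('lo ') -> '<lo>'. Tokens like this are
# commonly prepended to sentences to mark the language in multilingual NMT
# setups; its use is not shown in this file.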