import re, os
import nltk
from nltk.corpus import wordnet
import dill as pickle
import pandas as pd
import spacy  # needed by Tokenizer.__init__ when a spaCy language model is requested
from torchtext import data
from laonlp import tokenize

def multiple_replace(dict, text):
    # Create a regular expression from the dictionary keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
    # For each match, look up the corresponding value in the dictionary
    return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)

# get_synonym replaces a word with the vocabulary index of any WordNet synonym found in SRC
def get_synonym(word, SRC):
    syns = wordnet.synsets(word)
    for s in syns:
        for l in s.lemmas():
            if SRC.vocab.stoi[l.name()] != 0:
                return SRC.vocab.stoi[l.name()]
    return 0

class Tokenizer:
    def __init__(self, lang=None):
        if lang is not None:
            # Use a spaCy language model when one is specified
            self.nlp = spacy.load(lang)
            self.tokenizer_fn = self.nlp.tokenizer
        else:
            # Fall back to simple whitespace splitting
            self.tokenizer_fn = lambda l: l.strip().split()

    # def tokenize(self, sentence):
    #     sentence = re.sub(
    #         r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence))
    #     sentence = re.sub(r"[ ]+", " ", sentence)
    #     sentence = re.sub(r"\!+", "!", sentence)
    #     sentence = re.sub(r"\,+", ",", sentence)
    #     sentence = re.sub(r"\?+", "?", sentence)
    #     sentence = sentence.lower()
    #     return [tok.text for tok in self.tokenizer_fn(sentence) if tok.text != " "]

def read_data(src_file, trg_file):
    # Read parallel source/target corpora, one sentence per line
    with open(src_file) as f:
        src_data = f.read().strip().split('\n')
    with open(trg_file) as f:
        trg_data = f.read().strip().split('\n')
    return src_data, trg_data


def read_file(file_dir):
    with open(file_dir, 'r') as f:
        lines = f.read().strip().split('\n')
    return lines


def write_file(file_dir, content):
    with open(file_dir, "w") as f:
        f.write(content)

def create_fields(src_lang, trg_lang):
    # print("loading spacy tokenizers...")
    # t_src = tokenize(src_lang)
    # t_trg = tokenize(trg_lang)
    # t_src_tokenizer = t_trg_tokenizer = lambda x: x.strip().split()

    # Target side: plain whitespace tokenizer; source side: LaoNLP word tokenizer
    target_tokenizer = lambda x: x.strip().split()
    TRG = data.Field(lower=True, tokenize=target_tokenizer, init_token='<sos>', eos_token='<eos>')
    SRC = data.Field(lower=True, tokenize=tokenize.word_tokenize)
    return SRC, TRG

def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):
    print("creating dataset and iterator... ")
    raw_data = {'src': [line for line in src_data], 'trg': [line for line in trg_data]}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])

    # Drop sentence pairs longer than max_strlen (counted in whitespace-separated tokens)
    mask = (df['src'].str.count(' ') < max_strlen) & (df['trg'].str.count(' ') < max_strlen)
    df = df.loc[mask]

    # Round-trip through a temporary CSV so torchtext's TabularDataset can load the pairs
    df.to_csv("translate_transformer_temp.csv", index=False)
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)

    train_iter = MyIterator(train, batch_size=batchsize, device=device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=istrain, shuffle=True)
    os.remove('translate_transformer_temp.csv')

    if istrain:
        SRC.build_vocab(train)
        TRG.build_vocab(train)
    return train_iter

class MyIterator(data.Iterator):
    # Groups sentences of similar length into the same batch to minimise padding.
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                # Pull 100 batches worth of examples, sort them by length,
                # cut them into batches, then shuffle the order of those batches
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

global max_src_in_batch, max_tgt_in_batch

def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

def generate_language_token(lang: str):
    # Wrap a language code in angle brackets, e.g. 'lo' -> '<lo>'
    return '<{}>'.format(lang.strip())
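

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# It shows how the helpers above are typically wired together, assuming the
# legacy torchtext API imported at the top (data.Field / data.Iterator).
# The corpus paths are hypothetical placeholders, and max_strlen / batchsize
# are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # create_fields currently ignores its language arguments
    SRC, TRG = create_fields('lo', 'en')
    src_data, trg_data = read_data('data/train.src', 'data/train.trg')  # hypothetical paths
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_iter = create_dataset(src_data, trg_data, max_strlen=160, batchsize=1500,
                                device=device, SRC=SRC, TRG=TRG, istrain=True)
    for batch in train_iter:
        # Tensors are shaped (sequence length, batch size) with the default Field settings
        print(batch.src.shape, batch.trg.shape)
        break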