import re, os
import nltk
import spacy
from nltk.corpus import wordnet   # requires a one-time nltk.download('wordnet')
import dill as pickle
import pandas as pd
from torchtext import data
from laonlp import tokenize


def multiple_replace(replacements, text):
    # Build a single regular expression from the dictionary keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))

    # For each match, look up the corresponding value in the dictionary
    return regex.sub(lambda mo: replacements[mo.group()], text)
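
# Example (hypothetical replacement table): apply several substitutions in one pass.
#   multiple_replace({"’": "'", "“": '"', "”": '"'}, "He said “hi”")  ->  'He said "hi"'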

# get_synonym: return the SRC-vocab index of the first WordNet synonym of `word`
# that is in the vocabulary, or 0 (the <unk> index) if none is found.
def get_synonym(word, SRC):
    syns = wordnet.synsets(word)
    for s in syns:
        for l in s.lemmas():
            if SRC.vocab.stoi[l.name()] != 0:
                return SRC.vocab.stoi[l.name()]

    return 0
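
# Usage sketch: fall back to an in-vocabulary synonym for an out-of-vocabulary word.
#   idx = get_synonym("automobile", SRC)   # e.g. SRC.vocab.stoi["car"] if present, else 0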

class Tokenizer:
    def __init__(self, lang=None):
        # Use a spaCy tokenizer when a language/model name is given,
        # otherwise fall back to simple whitespace splitting.
        if lang is not None:
            self.nlp = spacy.load(lang)
            self.tokenizer_fn = self.nlp.tokenizer
        else:
            self.tokenizer_fn = lambda l: l.strip().split()
            
    # def tokenize(self, sentence):
    #     sentence = re.sub(
    #     r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence))
    #     sentence = re.sub(r"[ ]+", " ", sentence)
    #     sentence = re.sub(r"\!+", "!", sentence)
    #     sentence = re.sub(r"\,+", ",", sentence)
    #     sentence = re.sub(r"\?+", "?", sentence)
    #     sentence = sentence.lower()
    #     return [tok.text for tok in self.tokenizer_fn(sentence) if tok.text != " "]


def read_data(src_file, trg_file):
    with open(src_file, encoding='utf-8') as f:
        src_data = f.read().strip().split('\n')

    with open(trg_file, encoding='utf-8') as f:
        trg_data = f.read().strip().split('\n')

    return src_data, trg_data


def read_file(file_dir):
    with open(file_dir, 'r', encoding='utf-8') as f:
        data = f.read().strip().split('\n')
    return data

def write_file(file_dir, content):
    with open(file_dir, 'w', encoding='utf-8') as f:
        f.write(content)

def create_fields(src_lang, trg_lang):
    # Source (Lao) text is tokenized with laonlp's word tokenizer; the target side
    # is assumed to be pre-tokenized, so it is simply split on whitespace.
    # Note: src_lang and trg_lang are currently unused.
    target_tokenizer = lambda x: x.strip().split()

    TRG = data.Field(lower=True, tokenize=target_tokenizer, init_token='<sos>', eos_token='<eos>')
    SRC = data.Field(lower=True, tokenize=tokenize.word_tokenize)

    return SRC, TRG
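
# Usage sketch:
#   SRC, TRG = create_fields(None, None)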

def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):

    print("creating dataset and iterator... ")

    raw_data = {'src' : [line for line in src_data], 'trg': [line for line in trg_data]}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])
    
    mask = (df['src'].str.count(' ') < max_strlen) & (df['trg'].str.count(' ') < max_strlen)
    df = df.loc[mask]

    df.to_csv("translate_transformer_temp.csv", index=False)
    
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)

    train_iter = MyIterator(train, batch_size=batchsize, device=device,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=istrain, shuffle=True)
    
    os.remove('translate_transformer_temp.csv')
    
    if istrain:
        # Vocabularies are built from the training split only; for validation/test
        # data the fields are expected to already carry a vocabulary.
        SRC.build_vocab(train)
        TRG.build_vocab(train)

    return train_iter
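
# End-to-end sketch (hypothetical file names and hyperparameters):
#   src_data, trg_data = read_data('train.lo', 'train.trg')
#   SRC, TRG = create_fields(None, None)
#   train_iter = create_dataset(src_data, trg_data, max_strlen=160, batchsize=1500,
#                               device='cpu', SRC=SRC, TRG=TRG, istrain=True)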

class MyIterator(data.Iterator):
    """Iterator that, during training, pools examples, sorts each pool by length,
    cuts it into token-budgeted batches via batch_size_fn, and shuffles those batches;
    otherwise it simply sorts each batch by length."""

    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
            
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                          self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

global max_src_in_batch, max_tgt_in_batch

def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)
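
# Worked example: for count=3 examples with src lengths 5, 9, 7 and trg lengths 6, 4, 8,
# max_src_in_batch = 9 and max_tgt_in_batch = 8 + 2 = 10 (the +2 accounts for <sos>/<eos>),
# so the padded batch would occupy max(3*9, 3*10) = 30 token slots.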

def generate_language_token(lang: str):
    return '<{}>'.format(lang.strip())
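
# Example: generate_language_token('lo ') -> '<lo>'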