# arabic-auto-tashkeel / tashkeel_tokenizer.py
import re
import bw2ar
import torch
import xer
# Diacritics
FATHATAN = u'\u064b'
DAMMATAN = u'\u064c'
KASRATAN = u'\u064d'
FATHA = u'\u064e'
DAMMA = u'\u064f'
KASRA = u'\u0650'
SHADDA = u'\u0651'
SUKUN = u'\u0652'
TATWEEL = u'\u0640'
HARAKAT_PAT = re.compile(u"["+u"".join([FATHATAN, DAMMATAN, KASRATAN,
FATHA, DAMMA, KASRA, SUKUN,
SHADDA])+u"]")
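# For example, HARAKAT_PAT.sub('', u'كَتَبَ') yields u'كتب': the pattern matches the
# eight diacritics defined above (but not TATWEEL), leaving only the base letters.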
class TashkeelTokenizer:
def __init__(self):
self.letters = [' ', '$', '&', "'", '*', '<', '>', 'A', 'D', 'E', 'H', 'S', 'T', 'Y', 'Z',
'b', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 't',
'v', 'w', 'x', 'y', 'z', '|', '}'
]
self.letters = ['<PAD>', '<BOS>', '<EOS>'] + self.letters + ['<MASK>']
self.no_tashkeel_tag = '<NT>'
self.tashkeel_list = ['<NT>', '<SD>', '<SDD>', '<SF>', '<SFF>', '<SK>',
'<SKK>', 'F', 'K', 'N', 'a', 'i', 'o', 'u', '~']
self.tashkeel_list = ['<PAD>', '<BOS>', '<EOS>'] + self.tashkeel_list
self.tashkeel_map = {c:i for i,c in enumerate(self.tashkeel_list)}
self.letters_map = {c:i for i,c in enumerate(self.letters)}
self.inverse_tags = {
'~a': '<SF>', # shaddah and fatHa
'~u': '<SD>', # shaddah and Damma
'~i': '<SK>', # shaddah and kasra
'~F': '<SFF>', # shaddah and fatHatayn
'~N': '<SDD>', # shaddah and Dammatayn
'~K': '<SKK>' # shaddah and kasratayn
}
self.tags = {v:k for k,v in self.inverse_tags.items()}
self.shaddah_last = ['a~', 'u~', 'i~', 'F~', 'N~', 'K~']
self.shaddah_first = ['~a', '~u', '~i', '~F', '~N', '~K']
        self.tashkeel_chars = ['F', 'N', 'K', 'a', 'u', 'i', '~', 'o']
def clean_text(self, text):
        text = text.replace(TATWEEL, '') # strip tatweel
text = text.replace('ٱ', 'ا')
return ' '.join(re.sub(u"[^\u0621-\u063A\u0640-\u0652\u0670\u0671\ufefb\ufef7\ufef5\ufef9 ]", " ", text, flags=re.UNICODE).split())
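    # For example, clean_text(u'الْعَرَبِيَّة!') returns u'الْعَرَبِيَّة': characters outside
    # the kept Arabic letter/diacritic ranges are turned into spaces and runs of
    # whitespace are collapsed.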
def check_match(self, text_with_tashkeel, letter_n_tashkeel_pairs):
text_with_tashkeel = text_with_tashkeel.strip()
# test if the reconstructed text with tashkeel is the same as the original one
syn_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
return syn_text == text_with_tashkeel or syn_text == self.unify_shaddah_position(text_with_tashkeel)
def unify_shaddah_position(self, text_with_tashkeel):
# unify the order of shaddah and the harakah to make shaddah always at the beginning
for i in range(len(self.shaddah_first)):
text_with_tashkeel = text_with_tashkeel.replace(self.shaddah_last[i], self.shaddah_first[i])
return text_with_tashkeel
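    # For example, unify_shaddah_position('ba~ka~') returns 'b~ak~a' (Buckwalter
    # notation): every harakah-then-shaddah pair becomes shaddah-then-harakah.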
def split_tashkeel_from_text(self, text_with_tashkeel, test_match=True):
text_with_tashkeel = self.clean_text(text_with_tashkeel)
text_with_tashkeel = bw2ar.transliterate_text(text_with_tashkeel, 'ar2bw')
text_with_tashkeel = text_with_tashkeel.replace('`', '') # remove dagger 'alif
# unify the order of shaddah and the harakah to make shaddah always at the beginning
text_with_tashkeel = self.unify_shaddah_position(text_with_tashkeel)
# remove duplicated harakat
        for c in self.tashkeel_chars:
            text_with_tashkeel = text_with_tashkeel.replace(c*2, c)
letter_n_tashkeel_pairs = []
for i in range(len(text_with_tashkeel)): # go over the whole text
# check if the first character is a normal letter and the second character is a tashkeel
if i < (len(text_with_tashkeel) - 1) and not text_with_tashkeel[i] in self.tashkeel_list and text_with_tashkeel[i+1] in self.tashkeel_list:
# IMPORTANT: check if tashkeel is Shaddah, then there might be another Tashkeel char associated with it. If so,
# replace both Shaddah and the Tashkeel chars with the appropriate tag
if text_with_tashkeel[i+1] == '~':
                    # IMPORTANT: the following if statement relies on short-circuit evaluation.
                    # The first condition checks that there are more characters left before the
                    # second condition accesses position i+2 ("text_with_tashkeel[i+2]"), which
                    # would otherwise raise an "index out of range" exception. Note that Shaddah
                    # has already been placed before the Harakah at this point.
if i+2 < len(text_with_tashkeel) and f'~{text_with_tashkeel[i+2]}' in self.inverse_tags:
letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.inverse_tags[f'~{text_with_tashkeel[i+2]}']))
else:
# if it is only Shaddah, just add it to the list
letter_n_tashkeel_pairs.append((text_with_tashkeel[i], '~'))
else:
letter_n_tashkeel_pairs.append((text_with_tashkeel[i], text_with_tashkeel[i+1]))
# if the character at position i is a normal letter and has no Tashkeel, then add
# it with the tag "self.no_tashkeel_tag"
            # IMPORTANT: this elif block also ensures that no two or more consecutive tashkeel
            # characters are emitted, apart from the shaddah combinations handled above
elif not text_with_tashkeel[i] in self.tashkeel_list:
letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.no_tashkeel_tag))
if test_match:
# test if the split is done correctly by ensuring that we can retrieve back the original text
assert self.check_match(text_with_tashkeel, letter_n_tashkeel_pairs)
return [('<BOS>', '<BOS>')] + letter_n_tashkeel_pairs + [('<EOS>', '<EOS>')]
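    # For example, assuming bw2ar.transliterate_text performs a standard Buckwalter
    # transliteration, split_tashkeel_from_text(u'كَتَبَ') returns
    # [('<BOS>', '<BOS>'), ('k', 'a'), ('t', 'a'), ('b', 'a'), ('<EOS>', '<EOS>')].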
def combine_tashkeel_with_text(self, letter_n_tashkeel_pairs):
combined_with_tashkeel = []
for letter, tashkeel in letter_n_tashkeel_pairs:
combined_with_tashkeel.append(letter)
if tashkeel in self.tags:
combined_with_tashkeel.append(self.tags[tashkeel])
elif tashkeel != self.no_tashkeel_tag:
combined_with_tashkeel.append(tashkeel)
text = ''.join(combined_with_tashkeel)
return text
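    # For example, combine_tashkeel_with_text([('k', 'a'), ('t', 'a'), ('b', '<SF>')])
    # returns 'katab~a': compound tags such as '<SF>' are expanded back into
    # shaddah + harakah, and letters tagged '<NT>' are emitted without tashkeel.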
def encode(self, text_with_tashkeel, test_match=True):
letter_n_tashkeel_pairs = self.split_tashkeel_from_text(text_with_tashkeel, test_match)
text, tashkeel = zip(*letter_n_tashkeel_pairs)
input_ids = [self.letters_map[c] for c in text]
target_ids = [self.tashkeel_map[c] for c in tashkeel]
return torch.LongTensor(input_ids), torch.LongTensor(target_ids)
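    # For example, given the Buckwalter round trip sketched above, encode(u'كَتَبَ')
    # returns two LongTensors of length 5 each (<BOS> + three letters + <EOS>): the
    # first holds indices into self.letters, the second indices into self.tashkeel_list.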
def filter_tashkeel(self, tashkeel):
tmp = []
for i, t in enumerate(tashkeel):
if i != 0 and t == '<BOS>':
t = self.no_tashkeel_tag
elif i != (len(tashkeel) - 1) and t == '<EOS>':
t = self.no_tashkeel_tag
tmp.append(t)
tashkeel = tmp
return tashkeel
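    # filter_tashkeel above replaces any '<BOS>'/'<EOS>' predicted in the middle of a
    # sequence with the no-tashkeel tag, e.g. ['<BOS>', 'a', '<EOS>', 'i', '<EOS>']
    # becomes ['<BOS>', 'a', '<NT>', 'i', '<EOS>'].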
def decode(self, input_ids, target_ids):
# print('input_ids.shape:', input_ids.shape)
# print('target_ids.shape:', target_ids.shape)
input_ids = input_ids.cpu().tolist()
target_ids = target_ids.cpu().tolist()
ar_texts = []
for j in range(len(input_ids)):
letters = [self.letters[i] for i in input_ids[j]]
tashkeel = [self.tashkeel_list[i] for i in target_ids[j]]
letters = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', letters))
tashkeel = self.filter_tashkeel(tashkeel)
tashkeel = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', tashkeel))
            # VERY IMPORTANT NOTE: zip takes min(len(letters), len(tashkeel)) pairs and discards the rest of the letters / tashkeel
letter_n_tashkeel_pairs = list(zip(letters, tashkeel))
bw_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
ar_text = bw2ar.transliterate_text(bw_text, 'bw2ar')
ar_texts.append(ar_text)
return ar_texts
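    # For example, decode(input_ids.unsqueeze(0), target_ids.unsqueeze(0)) on the
    # tensors produced by the encode example returns a one-element list such as
    # [u'كَتَبَ'], assuming the bw2ar round trip is lossless for that input.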
def get_tashkeel_with_case_ending(self, text, case_ending=True):
text_split = self.split_tashkeel_from_text(text, test_match=False)
        text_spaces_indices = [i for i, el in enumerate(text_split) if el == (' ', '<NT>')]
new_text_split = []
for i, el in enumerate(text_split):
            if not case_ending and (i+1) in text_spaces_indices:
el = (el[0], '<NT>') # no case ending
new_text_split.append(el)
letters, tashkeel = zip(*new_text_split)
return letters, tashkeel
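    # get_tashkeel_with_case_ending above returns the letters and tashkeel of `text` as
    # two parallel tuples; with case_ending=False, the tashkeel of every letter that
    # directly precedes a space is replaced by '<NT>', so case endings are not scored.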
def compute_der(self, ref, hyp, case_ending=True):
_, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
_, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
ref_tashkeel = ' '.join(ref_tashkeel)
hyp_tashkeel = ' '.join(hyp_tashkeel)
return xer.wer(ref_tashkeel, hyp_tashkeel)
def compute_wer(self, ref, hyp, case_ending=True):
ref_letters, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
hyp_letters, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
ref_text_combined = self.combine_tashkeel_with_text(zip(ref_letters, ref_tashkeel))
hyp_text_combined = self.combine_tashkeel_with_text(zip(hyp_letters, hyp_tashkeel))
return xer.wer(ref_text_combined, hyp_text_combined)
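    # compute_der scores the tashkeel tag sequences alone (diacritic error rate), while
    # compute_wer scores the recombined letter+tashkeel words; both delegate to
    # xer.wer(ref, hyp), so identical ref and hyp should give an error rate of 0.0.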
def remove_tashkeel(self, text):
text = HARAKAT_PAT.sub('', text)
text = re.sub(u"[\u064E]", "", text, flags=re.UNICODE) # fattha
text = re.sub(u"[\u0671]", "", text, flags=re.UNICODE) # waSla
return text
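    # For example, remove_tashkeel(u'كَتَبَ') returns u'كتب': all harakat (and any alef
    # waSla) are removed while the base letters are preserved.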
if __name__ == '__main__':
import utils
from tqdm import tqdm
tokenizer = TashkeelTokenizer()
txt_folder_path = 'dataset/train'
prepared_lines = []
for filepath in utils.get_files(txt_folder_path, '*.txt'):
print(f'Reading file: {filepath}')
        with open(filepath, encoding='utf-8') as f1:
for line in f1:
clean_line = tokenizer.clean_text(line)
if clean_line != '':
prepared_lines.append(clean_line)
print(f'completed file: {filepath}')
good_sentences = []
bad_sentences = []
tokenized_sentences = []
for line in tqdm(prepared_lines):
try:
letter_n_tashkeel_pairs = tokenizer.split_tashkeel_from_text(line, test_match=True)
tokenized_sentences.append(letter_n_tashkeel_pairs)
good_sentences.append(line)
except AssertionError as e:
bad_sentences.append(line)
print('len(good_sentences), len(bad_sentences):', len(good_sentences), len(bad_sentences))
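
    # Illustrative round-trip check (a minimal sketch, assuming at least one sentence
    # passed the match test): encode one good sentence, decode it back through the
    # batch interface, and print the reconstructed Arabic text.
    if good_sentences:
        sample = good_sentences[0]
        input_ids, target_ids = tokenizer.encode(sample, test_match=False)
        decoded = tokenizer.decode(input_ids.unsqueeze(0), target_ids.unsqueeze(0))
        print('original  :', sample)
        print('round-trip:', decoded[0])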