jeduardogruiz
committed on
Upload test_tokenization_wav2vec2_phoneme.py
test_tokenization_wav2vec2_phoneme.py
ADDED
@@ -0,0 +1,440 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Wav2Vec2Phoneme tokenizer."""
import json
import os
import unittest
from typing import Tuple

from transformers import Wav2Vec2PhonemeCTCTokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizerOutput
from transformers.testing_utils import require_phonemizer

from test_tokenization_common import TokenizerTesterMixin


@require_phonemizer
class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "facebook/wav2vec2-lv-60-espeak-cv-ft"
    tokenizer_class = Wav2Vec2PhonemeCTCTokenizer
    test_rust_tokenizer = False

    def setUp(self):
        super().setUp()

        vocab = (
            "<s> <pad> </s> <unk> n s t ə l a i k d m ɛ ɾ e ɪ p o ɐ z ð f j v b ɹ ʁ ʊ iː r w ʌ u ɡ æ aɪ ʃ h ɔ ɑː "
            "ŋ ɚ eɪ β uː y ɑ̃ oʊ ᵻ eː θ aʊ ts oː ɔ̃ ɣ ɜ ɑ dʒ əl x ɜː ç ʒ tʃ ɔː ɑːɹ ɛ̃ ʎ ɔːɹ ʋ aː ɕ œ ø oːɹ ɲ yː "
            "ʔ iə i5 s. tɕ ?? nʲ ɛː œ̃ ɭ ɔø ʑ tʲ ɨ ɛɹ ts. rʲ ɪɹ ɭʲ i.5 ɔɪ q sʲ u5 ʊɹ iɜ a5 iɛ5 øː ʕ ja əɜ th ɑ5 "
            "oɪ dʲ ə5 tɕh ts.h mʲ ɯ dʑ vʲ e̞ tʃʲ ei5 o5 onɡ5 ɑu5 iɑ5 ai5 aɪɚ kh ə1 ʐ i2 ʉ ħ t[ aɪə ʲ ju ə2 u2 oɜ "
            "pː iɛɜ ou5 y5 uɜ tː uo5 d[ uoɜ tsh ɑɜ ɵ i̪5 uei5 ɟ aɜ ɑɨ i.ɜ eʊ o2 ɐ̃ ä pʲ kʲ n̩ ɒ ph ɑu2 uɨ əɪ ɫ ɬ "
            "yɜ bʲ ɑ2 s̪ aiɜ χ ɐ̃ʊ̃ 1 ə4 yæɜ a2 ɨː t̪ iouɜ ũ onɡɜ aɨ iɛ2 ɔɨ ɑuɜ o̞ ei2 iou2 c kː y2 ɖ oe dˤ yɛɜ "
            'əʊ S ɡʲ onɡ2 u" eiɜ ʈ ɯᵝ iou5 dZ r̝̊ i.2 tS s^ ʝ yə5 iɑɜ uə5 pf ɨu iɑ2 ou2 ər2 fʲ ai2 r̝ uəɜ ɳ əɨ '
            "ua5 uɪ ɽ bː yu5 uo2 yɛ5 l̩ ɻ ərɜ ʂ i̪2 ouɜ uaɜ a. a.ː yæ5 dː r̩ ee ɪu ər5 i̪ ɜ æi u: i.ː t^ o1 ɪ^ "
            "ai ueiɜ æː ɛɪ eə i. ɴ ie ua2 ɑ1 o4 tʃː o: ɑ: u1 N i̪1 au yæ2 u. qː yəɜ y: kʰ tʃʰ iʊ sx õ uo tʰ "
            "uai5 bʰ u.ː uə2 ʊə d^ s̪ː yiɜ dʰ r. oe: i1 ɟː yu2 nʲʲ i̪4 uei2 tsʲ ɸ ĩ ɑ4 t̪ː eɑ u4 e: tsː ʈʰ ɡʰ "
            "ɯɯ dʒʲ ʂʲ X ɵː uaiɜ tɕʲ ã t^ː ẽː yɛ2 cː i.1 ɛʊ dˤdˤ dʒː i4 ɡː yi ɕʲ ɟʰ pʰ dʑʲ yuɜ ua1 ua4 æiː ɐɐ "
            "ui iou1 ʊː a1 iou4 cʰ iɛ1 yə2 ɖʰ ẽ ʒʲ ää ər4 iːː ɪː iɑ1 ər1 œː øi ɪuː cʰcʰ əː1 iː1 ũ kʰː o̞o̞ xʲ "
            "ou1 iɛ4 e̞e̞ y1 dzː dʲʲ dʰː ɯᵝɯᵝ lː uo1 i.4 i: yɛ5ʲ a4"
        ).split(" ")
        vocab_tokens = dict(zip(vocab, range(len(vocab))))

        self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")

    # overwrite since phonemes require specific creation
    def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
        toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
        toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], do_phonemize=False), toks))
        if max_length is not None and len(toks) > max_length:
            toks = toks[:max_length]
        if min_length is not None and len(toks) < min_length and len(toks) > 0:
            while len(toks) < min_length:
                toks = toks + toks
        # toks_str = [t[1] for t in toks]
        toks_ids = [t[0] for t in toks]

        # Ensure consistency
        output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
        if " " not in output_txt and len(toks_ids) > 1:
            output_txt = (
                tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
                + " "
                + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
            )
        if with_prefix_space:
            output_txt = " " + output_txt
        output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
        return output_txt, output_ids

    def get_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return Wav2Vec2PhonemeCTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def test_tokenizer_add_new_tokens(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")

        # check adding a single token
        tokenizer.add_tokens("xxx")
        token_ids = tokenizer("m xxx ɪ", do_phonemize=False).input_ids
        self.assertEqual(token_ids, [13, 392, 17])  # xxx should be last token

        tokenizer.add_tokens(["aaa", "bbb", "ccc"])
        token_ids = tokenizer("m aaa ɪ ccc", do_phonemize=False).input_ids
        self.assertEqual(token_ids, [13, 393, 17, 395])  # aaa and ccc should be after xxx and 2 after aaa

        token_ids = tokenizer("maɪ c", do_phonemize=False).input_ids
        self.assertEqual(token_ids, [3, 200])  # mai should be <unk> (=3)

    def test_phonemize(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")

        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")
        self.assertEqual(phonemes, "h ə l oʊ h aʊ ɑːɹ j uː")

    def test_encode(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")

        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")
        self.assertEqual(tokenizer(input_text).input_ids, tokenizer(phonemes, do_phonemize=False).input_ids)

    def test_encode_decode(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")

        phonemes_enc_dec = tokenizer.decode(tokenizer(input_text).input_ids)

        self.assertEqual(phonemes, phonemes_enc_dec)

    def test_decode(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")

        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
            [24, 22, 5, 24, 22, 5, 77],
        ]
        tokens = tokenizer.decode(sample_ids[0])
        batch_tokens = tokenizer.batch_decode(sample_ids)
        self.assertEqual(tokens, batch_tokens[0])
        self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ", "j ð s j ð s oːɹ"])

    def test_phonemize_with_word_del(self):
        tokenizer = self.tokenizer_class.from_pretrained(
            "facebook/wav2vec2-lv-60-espeak-cv-ft", word_delimiter_token="|"
        )
        tokenizer.add_tokens("|")

        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")
        self.assertEqual(phonemes, "h ə l oʊ | h aʊ | ɑːɹ | j uː |")

    def test_encode_with_del(self):
        tokenizer = self.tokenizer_class.from_pretrained(
            "facebook/wav2vec2-lv-60-espeak-cv-ft", word_delimiter_token="|"
        )
        tokenizer.add_tokens("|")

        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")
        self.assertEqual(tokenizer(input_text).input_ids, tokenizer(phonemes, do_phonemize=False).input_ids)

    def test_decode_with_del(self):
        tokenizer = self.tokenizer_class.from_pretrained(
            "facebook/wav2vec2-lv-60-espeak-cv-ft", word_delimiter_token="|"
        )
        tokenizer.add_tokens("|")

        # fmt: off
        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, tokenizer.word_delimiter_token_id, 15, 8, tokenizer.word_delimiter_token_id, 98],
            [tokenizer.word_delimiter_token_id, 24, 22, tokenizer.word_delimiter_token_id, 5, 24, 22, 5, 77],
        ]
        # fmt: on

        # decode with word_del_token filter
        tokens = tokenizer.decode(sample_ids[0])
        batch_tokens = tokenizer.batch_decode(sample_ids)
        self.assertEqual(tokens, batch_tokens[0])
        self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ", "j ð s j ð s oːɹ"])

        # decode with no word_del_token filter
        tokens = tokenizer.decode(sample_ids[0], filter_word_delimiter_token=False)
        batch_tokens = tokenizer.batch_decode(sample_ids, filter_word_delimiter_token=False)
        self.assertEqual(tokens, batch_tokens[0])
        self.assertEqual(batch_tokens, ["k s ɾ | ɾ l | ɭʲ", "| j ð | s j ð s oːɹ"])

    def test_encode_decode_with_del(self):
        tokenizer = self.tokenizer_class.from_pretrained(
            "facebook/wav2vec2-lv-60-espeak-cv-ft", word_delimiter_token="|"
        )
        tokenizer.add_tokens("|")

        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")

        phonemes_enc_dec = tokenizer.decode(tokenizer(input_text).input_ids, filter_word_delimiter_token=False)

        self.assertEqual(phonemes, phonemes_enc_dec)

    def test_encode_decode_with_del_filter(self):
        tokenizer = self.tokenizer_class.from_pretrained(
            "facebook/wav2vec2-lv-60-espeak-cv-ft", word_delimiter_token="|"
        )
        tokenizer.add_tokens("|")

        input_text = "Hello how are you"
        phonemes = tokenizer.phonemize(input_text, phonemizer_lang="en-us")

        phonemes_enc_dec = tokenizer.decode(tokenizer(input_text).input_ids, filter_word_delimiter_token=True)

        self.assertEqual(" ".join([p.strip() for p in phonemes.split(" |")]).strip(), phonemes_enc_dec)

    def test_change_phonemizer_lang(self):
        tokenizer = self.tokenizer_class.from_pretrained(
            "facebook/wav2vec2-lv-60-espeak-cv-ft", word_delimiter_token=None
        )
        input_text = "Hello how are you"

        input_ids_en = tokenizer(input_text, phonemizer_lang="en-us").input_ids
        input_ids_fr = tokenizer(input_text, phonemizer_lang="fr-fr").input_ids

        self.assertNotEqual(input_ids_en, input_ids_fr)

        text_en = tokenizer.decode(input_ids_en)
        text_fr = tokenizer.decode(input_ids_fr)

        self.assertEqual(text_en, "h ə l oʊ h aʊ ɑːɹ j uː")
        self.assertEqual(text_fr, "ɛ l o h aʊ a ʁ j u")

    def test_case_insensitive(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
        input_text_up = "Hello how Are you"
        input_text_low = "hello how are you"

        input_ids_up = tokenizer(input_text_up).input_ids
        input_ids_low = tokenizer(input_text_low).input_ids

        self.assertEqual(input_ids_up, input_ids_low)

    def test_tokenizer_decode_added_tokens(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
        tokenizer.add_tokens(["!", "?"])
        tokenizer.add_special_tokens({"cls_token": "$$$"})

        # fmt: off
        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 392, 392, 393, 392, 392, 393, 394, 394],
            [24, 22, 5, 24, 22, 5, 77, tokenizer.pad_token_id, 394, 394],
        ]
        # fmt: on

        batch_tokens = tokenizer.batch_decode(sample_ids)
        self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? $$$", "j ð s j ð s oːɹ $$$"])

    @staticmethod
    def get_from_offsets(offsets, key):
        retrieved_list = [d[key] for d in offsets]
        return retrieved_list

    def test_offsets(self):
        tokenizer = self.get_tokenizer(word_delimiter_token="|")
        tokenizer.add_tokens("|")

        # fmt: off
        # ksssɾɾ|ɾɾ<pad>ɾɾ|<pad>ɾlll|ɭʲ -> k s ɾ ɾ | ɾ l | ɭʲ"
        sample_ids = [11, 5, 5, 5, 15, 15, tokenizer.pad_token_id, 15, 15, tokenizer.word_delimiter_token_id, tokenizer.pad_token_id, 15, 8, 8, 8, tokenizer.word_delimiter_token_id, 98]
        # fmt: on

        outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False)
        # check Wav2Vec2CTCTokenizerOutput keys for char
        self.assertEqual(len(outputs.keys()), 2)
        self.assertTrue("text" in outputs)
        self.assertTrue("char_offsets" in outputs)
        self.assertTrue(isinstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput))

        # check that order of chars is correct and identical for both outputs
        self.assertEqual(" ".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
        self.assertListEqual(
            self.get_from_offsets(outputs["char_offsets"], "char"), ["k", "s", "ɾ", "ɾ", "|", "ɾ", "l", "|", "ɭʲ"]
        )

        # check that offsets are actually correct for char
        # 0-1 is 11, 1-4 is 5, 4-6 is first 15, 6-7 is <pad> (thus not shown), 7-9 is second 15, 9-10 is word_delimiter_token,
        # 10-11 is <pad> (thus not shown), 11-12 is third 15, 12-15 is 8, 15-16 is word_delimiter_token, 16-17 is 98
        self.assertListEqual(
            self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 4, 7, 9, 11, 12, 15, 16]
        )
        self.assertListEqual(
            self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 4, 6, 9, 10, 12, 15, 16, 17]
        )

    def test_offsets_batch(self):
        tokenizer = self.get_tokenizer(word_delimiter_token="|")

        def check_list_tuples_equal(outputs_batch, outputs_list):
            self.assertTrue(isinstance(outputs_batch, Wav2Vec2PhonemeCTCTokenizerOutput))
            self.assertTrue(isinstance(outputs_list[0], Wav2Vec2PhonemeCTCTokenizerOutput))

            # transform list to ModelOutput
            outputs_batch_2 = Wav2Vec2PhonemeCTCTokenizerOutput(
                {k: [d[k] for d in outputs_list] for k in outputs_list[0]}
            )

            self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])

            def recursive_check(list_or_dict_1, list_or_dict_2):
                if isinstance(list_or_dict_1, list):
                    [recursive_check(l1, l2) for l1, l2 in zip(list_or_dict_1, list_or_dict_2)]
                self.assertEqual(list_or_dict_1, list_or_dict_2)

            if "char_offsets" in outputs_batch:
                recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])

        # fmt: off
        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
            [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
        ]
        # fmt: on

        # We assume that `decode` works as expected. All we will check now is
        # the output type is correct and the output is identical to `decode`

        # char
        outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
        outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
        check_list_tuples_equal(outputs_char_batch, outputs_char)

    @unittest.skip("Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes")
    def test_added_tokens_do_lower_case(self):
        pass

    @unittest.skip("Wav2Vec2PhonemeTokenizer always puts spaces between phonemes")
    def test_encode_decode_with_spaces(self):
        pass

    @unittest.skip("encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency")
    def test_internal_consistency(self):
        pass

    @unittest.skip("Wav2Vec2PhonemeModel has no max model length => no testing")
    def test_add_tokens_tokenizer(self):
        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                vocab_size = tokenizer.vocab_size
                all_size = len(tokenizer)

                self.assertNotEqual(vocab_size, 0)

                # We usually have added tokens from the start in tests because our vocab fixtures are
                # smaller than the original vocabs - let's not assert this
                # self.assertEqual(vocab_size, all_size)

                new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
                added_toks = tokenizer.add_tokens(new_toks)
                vocab_size_2 = tokenizer.vocab_size
                all_size_2 = len(tokenizer)

                self.assertNotEqual(vocab_size_2, 0)
                self.assertEqual(vocab_size, vocab_size_2)
                self.assertEqual(added_toks, len(new_toks))
                self.assertEqual(all_size_2, all_size + len(new_toks))

                tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

                self.assertGreaterEqual(len(tokens), 4)
                self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)

                new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
                added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
                vocab_size_3 = tokenizer.vocab_size
                all_size_3 = len(tokenizer)

                self.assertNotEqual(vocab_size_3, 0)
                self.assertEqual(vocab_size, vocab_size_3)
                self.assertEqual(added_toks_2, len(new_toks_2))
                self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

                tokens = tokenizer.encode(
                    ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
                )

                self.assertGreaterEqual(len(tokens), 6)
                self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[0], tokens[1])
                self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[-3], tokens[-4])
                self.assertEqual(tokens[0], tokenizer.eos_token_id)
                self.assertEqual(tokens[-3], tokenizer.pad_token_id)

    @unittest.skip("The tokenizer shouldn't be used to encode input IDs (except for labels), only to decode.")
    def test_tf_encode_plus_sent_to_model(self):
        pass

    @unittest.skip("The tokenizer shouldn't be used to encode input IDs (except for labels), only to decode.")
    def test_torch_encode_plus_sent_to_model(self):
        pass

    def test_convert_tokens_to_string_format(self):
        # The default common tokenizer tests assumes that the output of `convert_tokens_to_string` is a string which
        # is not the case for Wav2Vec2PhonemeCTCTokenizer.
        tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                tokens = ["ð", "ɪ", "s", "ɪ", "z", "ɐ", "t", "ɛ", "k", "s", "t"]
                output = tokenizer.convert_tokens_to_string(tokens)

                self.assertIsInstance(output["text"], str)
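For context, a minimal usage sketch of the tokenizer these tests exercise (illustrative only, not part of the committed file; it assumes the phonemizer backend and its espeak dependency are installed and reuses the checkpoint loaded throughout the tests):

from transformers import Wav2Vec2PhonemeCTCTokenizer

# Same checkpoint the tests above load.
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")

# Raw text is phonemized before being mapped to ids; the expected string below
# is the one asserted in test_phonemize.
phonemes = tokenizer.phonemize("Hello how are you", phonemizer_lang="en-us")
print(phonemes)  # h ə l oʊ h aʊ ɑːɹ j uː

# Encoding the raw text and encoding the phoneme string with do_phonemize=False
# yield the same ids, which is what test_encode asserts.
ids = tokenizer("Hello how are you").input_ids
assert ids == tokenizer(phonemes, do_phonemize=False).input_ids

# Decoding maps the ids back to a phoneme string.
print(tokenizer.decode(ids))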