# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Whisper."""
import json
import os
import warnings
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
import regex as re

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer


VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "tokenizer_file": "tokenizer.json",
    "merges_file": "merges.txt",
    "normalizer_file": "normalizer.json",
}

MAX_MODEL_INPUT_SIZES = {
    "openai/whisper-base": 448,
}
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
@lru_cache()
def bytes_to_unicode():
""" | |
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control | |
characters the bpe code barfs on. | |
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab | |
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for | |
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup | |
tables between utf-8 bytes and unicode strings. | |
""" | |
bs = ( | |
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) | |
) | |
cs = bs[:] | |
n = 0 | |
for b in range(2**8): | |
if b not in bs: | |
bs.append(b) | |
cs.append(2**8 + n) | |
n += 1 | |
cs = [chr(n) for n in cs] | |
return dict(zip(bs, cs)) | |
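# Illustrative note (not part of the original mapping logic): printable bytes map to themselves,
# e.g. bytes_to_unicode()[ord("a")] == "a", while non-printable bytes are shifted past 255, e.g.
# the space byte 0x20 maps to chr(256 + 32) == "Ġ", so BPE never operates on raw whitespace.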
logger = logging.get_logger(__name__) | |
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs | |
def get_pairs(word): | |
""" | |
Return set of symbol pairs in a word. | |
Word is represented as tuple of symbols (symbols being variable-length strings). | |
""" | |
pairs = set() | |
prev_char = word[0] | |
for char in word[1:]: | |
pairs.add((prev_char, char)) | |
prev_char = char | |
return pairs | |
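# Illustrative note: get_pairs(("h", "e", "l", "l", "o")) returns
# {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")} -- the candidate bigrams that the `bpe`
# method ranks against `self.bpe_ranks` at every merge step.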
LANGUAGES = { | |
"en": "english", | |
"zh": "chinese", | |
"de": "german", | |
"es": "spanish", | |
"ru": "russian", | |
"ko": "korean", | |
"fr": "french", | |
"ja": "japanese", | |
"pt": "portuguese", | |
"tr": "turkish", | |
"pl": "polish", | |
"ca": "catalan", | |
"nl": "dutch", | |
"ar": "arabic", | |
"sv": "swedish", | |
"it": "italian", | |
"id": "indonesian", | |
"hi": "hindi", | |
"fi": "finnish", | |
"vi": "vietnamese", | |
"he": "hebrew", | |
"uk": "ukrainian", | |
"el": "greek", | |
"ms": "malay", | |
"cs": "czech", | |
"ro": "romanian", | |
"da": "danish", | |
"hu": "hungarian", | |
"ta": "tamil", | |
"no": "norwegian", | |
"th": "thai", | |
"ur": "urdu", | |
"hr": "croatian", | |
"bg": "bulgarian", | |
"lt": "lithuanian", | |
"la": "latin", | |
"mi": "maori", | |
"ml": "malayalam", | |
"cy": "welsh", | |
"sk": "slovak", | |
"te": "telugu", | |
"fa": "persian", | |
"lv": "latvian", | |
"bn": "bengali", | |
"sr": "serbian", | |
"az": "azerbaijani", | |
"sl": "slovenian", | |
"kn": "kannada", | |
"et": "estonian", | |
"mk": "macedonian", | |
"br": "breton", | |
"eu": "basque", | |
"is": "icelandic", | |
"hy": "armenian", | |
"ne": "nepali", | |
"mn": "mongolian", | |
"bs": "bosnian", | |
"kk": "kazakh", | |
"sq": "albanian", | |
"sw": "swahili", | |
"gl": "galician", | |
"mr": "marathi", | |
"pa": "punjabi", | |
"si": "sinhala", | |
"km": "khmer", | |
"sn": "shona", | |
"yo": "yoruba", | |
"so": "somali", | |
"af": "afrikaans", | |
"oc": "occitan", | |
"ka": "georgian", | |
"be": "belarusian", | |
"tg": "tajik", | |
"sd": "sindhi", | |
"gu": "gujarati", | |
"am": "amharic", | |
"yi": "yiddish", | |
"lo": "lao", | |
"uz": "uzbek", | |
"fo": "faroese", | |
"ht": "haitian creole", | |
"ps": "pashto", | |
"tk": "turkmen", | |
"nn": "nynorsk", | |
"mt": "maltese", | |
"sa": "sanskrit", | |
"lb": "luxembourgish", | |
"my": "myanmar", | |
"bo": "tibetan", | |
"tl": "tagalog", | |
"mg": "malagasy", | |
"as": "assamese", | |
"tt": "tatar", | |
"haw": "hawaiian", | |
"ln": "lingala", | |
"ha": "hausa", | |
"ba": "bashkir", | |
"jw": "javanese", | |
"su": "sundanese", | |
"yue": "cantonese", | |
} | |
# language code lookup by name, with a few language aliases | |
TO_LANGUAGE_CODE = { | |
**{language: code for code, language in LANGUAGES.items()}, | |
"burmese": "my", | |
"valencian": "ca", | |
"flemish": "nl", | |
"haitian": "ht", | |
"letzeburgesch": "lb", | |
"pushto": "ps", | |
"panjabi": "pa", | |
"moldavian": "ro", | |
"moldovan": "ro", | |
"sinhalese": "si", | |
"castilian": "es", | |
"mandarin": "zh", | |
} | |
TASK_IDS = ["translate", "transcribe"] | |
class WhisperTokenizer(PreTrainedTokenizer): | |
""" | |
Construct a Whisper tokenizer. | |
This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to | |
the superclass for more information regarding such methods. | |
Args: | |
vocab_file (`str`): | |
Path to the vocabulary file. | |
merges_file (`str`): | |
Path to the merges file. | |
normalizer_file (`str`, *optional*): | |
            Path to the normalizer file.
errors (`str`, *optional*, defaults to `"replace"`): | |
Paradigm to follow when decoding bytes to UTF-8. See | |
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. | |
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): | |
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this | |
token instead. | |
bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): | |
The beginning of sequence token. The `decoder_start_token_id` is used to set the first token as | |
`"<|startoftranscript|>"` when generating. | |
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): | |
The end of sequence token. | |
pad_token (`str`, *optional*): | |
The token used for padding, for example when batching sequences of different lengths. | |
add_prefix_space (`bool`, *optional*, defaults to `False`): | |
            Whether or not to add an initial space to the input. This allows the leading word to be treated just like
            any other word.
language (`str`, *optional*): | |
The language of the transcription text. The corresponding language id token is appended to the start of the | |
sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token | |
`"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only. | |
task (`str`, *optional*): | |
            Task identifier to append at the start of the sequence (if any). This should be used for multilingual
            fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
predict_timestamps (`bool`, *optional*, defaults to `False`): | |
Whether to omit the `<|notimestamps|>` token at the start of the sequence. | |
""" | |
vocab_files_names = VOCAB_FILES_NAMES | |
model_input_names = ["input_ids", "attention_mask"] | |
def __init__( | |
self, | |
vocab_file, | |
merges_file, | |
normalizer_file=None, | |
errors="replace", | |
unk_token="<|endoftext|>", | |
bos_token="<|endoftext|>", | |
eos_token="<|endoftext|>", | |
pad_token=None, | |
add_prefix_space=False, | |
language=None, | |
task=None, | |
predict_timestamps=False, | |
**kwargs, | |
): | |
bos_token = ( | |
AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False, special=True) | |
if isinstance(bos_token, str) | |
else bos_token | |
) | |
eos_token = ( | |
AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False, special=True) | |
if isinstance(eos_token, str) | |
else eos_token | |
) | |
unk_token = ( | |
AddedToken(unk_token, lstrip=False, rstrip=False, normalized=False, special=True) | |
if isinstance(unk_token, str) | |
else unk_token | |
) | |
pad_token = ( | |
AddedToken(pad_token, lstrip=False, rstrip=False, normalized=False, special=True) | |
if isinstance(pad_token, str) | |
else pad_token | |
) | |
with open(vocab_file, encoding="utf-8") as vocab_handle: | |
self.encoder = json.load(vocab_handle) | |
self.decoder = {v: k for k, v in self.encoder.items()} | |
self.errors = errors # how to handle errors in decoding | |
self.byte_encoder = bytes_to_unicode() | |
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |
with open(merges_file, encoding="utf-8") as merges_handle: | |
bpe_merges = merges_handle.read().split("\n")[1:-1] | |
bpe_merges = [tuple(merge.split()) for merge in bpe_merges] | |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |
self.cache = {} | |
self.add_prefix_space = add_prefix_space | |
if normalizer_file is not None: | |
with open(normalizer_file, encoding="utf-8") as vocab_handle: | |
self.english_spelling_normalizer = json.load(vocab_handle) | |
else: | |
self.english_spelling_normalizer = None | |
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | |
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | |
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>") | |
self.language = language | |
super().__init__( | |
errors=errors, | |
unk_token=unk_token, | |
bos_token=bos_token, | |
eos_token=eos_token, | |
pad_token=pad_token, | |
add_prefix_space=add_prefix_space, | |
**kwargs, | |
) | |
self.task = task | |
self.predict_timestamps = predict_timestamps | |
    @property
    def vocab_size(self) -> int:
return len(self.encoder) | |
def get_vocab(self): | |
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} | |
vocab.update(self.added_tokens_encoder) | |
return vocab | |
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe with GPT2 -> Whisper | |
def bpe(self, token): | |
if token in self.cache: | |
return self.cache[token] | |
word = tuple(token) | |
pairs = get_pairs(word) | |
if not pairs: | |
return token | |
while True: | |
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |
if bigram not in self.bpe_ranks: | |
break | |
first, second = bigram | |
new_word = [] | |
i = 0 | |
while i < len(word): | |
try: | |
j = word.index(first, i) | |
except ValueError: | |
new_word.extend(word[i:]) | |
break | |
else: | |
new_word.extend(word[i:j]) | |
i = j | |
if word[i] == first and i < len(word) - 1 and word[i + 1] == second: | |
new_word.append(first + second) | |
i += 2 | |
else: | |
new_word.append(word[i]) | |
i += 1 | |
new_word = tuple(new_word) | |
word = new_word | |
if len(word) == 1: | |
break | |
else: | |
pairs = get_pairs(word) | |
word = " ".join(word) | |
self.cache[token] = word | |
return word | |
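    # Illustrative sketch with toy merges (not the real merge table): if self.bpe_ranks were
    # {("l", "o"): 0, ("lo", "w"): 1}, bpe("low") would merge ("l", "o") then ("lo", "w") and
    # return "low"; with no applicable merges, bpe("xyz") returns the space-joined "x y z".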
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None): | |
""" | |
Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to | |
update the prefix tokens as required when fine-tuning. Example: | |
```python | |
>>> # instantiate the tokenizer and set the prefix token to Spanish | |
>>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish") | |
>>> # now switch the prefix token from Spanish to French | |
>>> tokenizer.set_prefix_tokens(language="french") | |
``` | |
Args: | |
language (`str`, *optional*, defaults to `None`): | |
The language of the transcription text. | |
task (`str`, *optional*, defaults to `None`): | |
Task identifier to append at the start of sequence (if any). | |
predict_timestamps (`bool`, *optional*, defaults to `None`): | |
Whether to omit the `<|notimestamps|>` token at the start of the sequence. | |
""" | |
self.language = language if language is not None else self.language | |
self.task = task if task is not None else self.task | |
self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps | |
    @property
    def prefix_tokens(self) -> List[int]:
bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>") | |
translate_token_id = self.convert_tokens_to_ids("<|translate|>") | |
transcribe_token_id = self.convert_tokens_to_ids("<|transcribe|>") | |
notimestamps_token_id = self.convert_tokens_to_ids("<|notimestamps|>") | |
langs = tuple(LANGUAGES.keys()) | |
if self.language is not None: | |
self.language = self.language.lower() | |
if self.language in TO_LANGUAGE_CODE: | |
language_id = TO_LANGUAGE_CODE[self.language] | |
elif self.language in TO_LANGUAGE_CODE.values(): | |
language_id = self.language | |
else: | |
is_language_code = len(self.language) == 2 | |
raise ValueError( | |
f"Unsupported language: {self.language}. Language should be one of:" | |
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." | |
) | |
if self.task is not None: | |
if self.task not in TASK_IDS: | |
raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}") | |
bos_sequence = [bos_token_id] | |
if self.language is not None: | |
bos_sequence.append(bos_token_id + 1 + langs.index(language_id)) | |
if self.task is not None: | |
bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id) | |
if not self.predict_timestamps: | |
bos_sequence.append(notimestamps_token_id) | |
return bos_sequence | |
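    # Illustrative note (token strings shown instead of ids): with language="english",
    # task="transcribe" and predict_timestamps=False, the prefix decodes to
    # <|startoftranscript|><|en|><|transcribe|><|notimestamps|>; unset options simply drop their
    # token from the sequence.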
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens | |
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: | |
"""Build model inputs from a sequence by appending eos_token_id.""" | |
if token_ids_1 is None: | |
return self.prefix_tokens + token_ids_0 + [self.eos_token_id] | |
# We don't expect to process pairs, but leave the pair logic for API consistency | |
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id] | |
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask | |
def get_special_tokens_mask( | |
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False | |
) -> List[int]: | |
""" | |
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding | |
special tokens using the tokenizer `prepare_for_model` method. | |
Args: | |
token_ids_0 (`List[int]`): | |
List of IDs. | |
token_ids_1 (`List[int]`, *optional*): | |
Optional second list of IDs for sequence pairs. | |
already_has_special_tokens (`bool`, *optional*, defaults to `False`): | |
Whether or not the token list is already formatted with special tokens for the model. | |
Returns: | |
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. | |
""" | |
if already_has_special_tokens: | |
return super().get_special_tokens_mask( | |
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True | |
) | |
prefix_ones = [1] * len(self.prefix_tokens) | |
suffix_ones = [1] | |
if token_ids_1 is None: | |
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones | |
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones | |
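    # Illustrative note: with language and task set (a four-token prefix), three content ids
    # produce the mask [1, 1, 1, 1, 0, 0, 0, 1] -- ones for the prefix tokens and the trailing
    # eos token, zeros for the transcription itself.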
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper | |
def _tokenize(self, text): | |
"""Tokenize a string.""" | |
bpe_tokens = [] | |
for token in re.findall(self.pat, text): | |
token = "".join( | |
self.byte_encoder[b] for b in token.encode("utf-8") | |
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) | |
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) | |
return bpe_tokens | |
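    # Illustrative note: " Hello" is first byte-encoded to "ĠHello" (the leading space becomes
    # "Ġ") and then split by BPE into sub-word strings such as ["ĠHello"] or ["ĠHe", "llo"],
    # depending on the learned merges.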
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id with GPT2 -> Whisper | |
def _convert_token_to_id(self, token): | |
"""Converts a token (str) in an id using the vocab.""" | |
return self.encoder.get(token, self.encoder.get(self.unk_token)) | |
def _convert_id_to_token(self, index): | |
""" | |
        Converts an index (integer) into a token (str) using the vocab. Whisper's base tokenizer always decodes OOV
tokens as "", thus we do not use the `unk_token` here. | |
""" | |
return self.decoder.get(index, "") | |
def _normalize(self, text): | |
warnings.warn( | |
"The private method `_normalize` is deprecated and will be removed in v5 of Transformers." | |
"You can normalize an input string using the Whisper English normalizer using the `normalize` method." | |
) | |
return self.normalize(text) | |
def _basic_normalize(self, text, remove_diacritics=False): | |
warnings.warn( | |
"The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers." | |
"You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method." | |
) | |
return self.basic_normalize(text, remove_diacritics=remove_diacritics) | |
def normalize(self, text): | |
""" | |
        Normalize a given string using the `EnglishTextNormalizer` class, which performs common transformations on
        English text.
""" | |
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) | |
return normalizer(text) | |
    @staticmethod
    def basic_normalize(text, remove_diacritics=False):
        """
        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
        multilingual text.
""" | |
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics) | |
return normalizer(text) | |
def _decode_with_timestamps( | |
self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500 | |
) -> str: | |
""" | |
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
        given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
""" | |
timestamp_begin = self.all_special_ids[-1] + 1 | |
outputs = [[]] | |
cur_max_timestamp = 0.0 | |
prev_segments_len = 0.0 | |
penultimate_timestamp = 0.0 | |
for i, token in enumerate(token_ids): | |
if token >= timestamp_begin: | |
timestamp = float((token - timestamp_begin) * time_precision) | |
if timestamp < cur_max_timestamp: | |
# next segment has started | |
last_was_single_ending = i >= 2 and not ( | |
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin | |
) | |
if last_was_single_ending: | |
prev_segments_len += time_precision * segment_size | |
else: | |
cur_max_timestamp = penultimate_timestamp | |
prev_segments_len += penultimate_timestamp | |
outputs = outputs[:-2] | |
penultimate_timestamp = cur_max_timestamp | |
cur_max_timestamp = timestamp | |
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") | |
outputs.append([]) | |
else: | |
outputs[-1].append(token) | |
outputs = [ | |
s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs | |
] | |
return "".join(outputs) | |
def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): | |
""" | |
Compute offsets for a given tokenized input | |
Args: | |
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): | |
List of tokenized input ids. Can be obtained using the `__call__` method. | |
time_precision (`float`, *optional*, defaults to 0.02): | |
The time ratio to convert from token to time. | |
segment_size (`int`, *optional*, defaults to 1500): | |
The number of features in the input mel spectrogram. | |
""" | |
offsets = [] | |
# ensure torch tensor of token ids is placed on cpu | |
if "torch" in str(type(token_ids)) and (hasattr(token_ids, "cpu") and callable(token_ids.cpu)): | |
token_ids = token_ids.cpu() | |
token_ids = np.array(token_ids) | |
if token_ids.shape[0] > 1 and len(token_ids.shape) > 1: | |
raise ValueError("Can only process a single input at a time") | |
timestamp_begin = self.all_special_ids[-1] + 1 | |
timestamp_tokens = token_ids >= timestamp_begin | |
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 | |
if consecutive.shape[0] == 0 and timestamp_tokens.sum() <= 1: | |
# either there are no timestamps or there are no consecutive ones | |
return [] | |
elif np.where(timestamp_tokens)[0][-1] + 1 not in consecutive: | |
# we add the final timestamp if it is not already in the list | |
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1) | |
last_slice = np.where(timestamp_tokens)[0][0] | |
cur_max_timestamp = 0 | |
prev_segments_len = 0 | |
for current_slice in consecutive: | |
sliced_tokens = token_ids[last_slice:current_slice] | |
if len(sliced_tokens) > 1: | |
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin | |
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin | |
if start_timestamp_position < cur_max_timestamp: | |
# next segment has started | |
is_single_ending = last_slice >= 2 and not ( | |
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin | |
) | |
if is_single_ending: | |
prev_segments_len += segment_size | |
else: | |
prev_segments_len += cur_max_timestamp | |
cur_max_timestamp = end_timestamp_position | |
# strip timestamp tokens from the text output | |
sliced_tokens = self._preprocess_token_ids(sliced_tokens) | |
text = self._decode(sliced_tokens) | |
text = self._filter_timestamp_ids(text) | |
offsets.append( | |
{ | |
"text": text, | |
"timestamp": ( | |
start_timestamp_position * time_precision + prev_segments_len * time_precision, | |
end_timestamp_position * time_precision + prev_segments_len * time_precision, | |
), | |
} | |
) | |
last_slice = current_slice | |
return offsets | |
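    # Illustrative note (hypothetical values): for a single timestamped utterance this returns
    # something like [{"text": " Hello world.", "timestamp": (0.0, 1.08)}], where the tuple is
    # (start, end) in seconds, including the length of any preceding 30-second segments.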
def timestamp_ids(self, time_precision=0.02): | |
""" | |
Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache. | |
Args: | |
time_precision (`float`, *optional*, defaults to 0.02): | |
The time ratio to convert from token to time. | |
""" | |
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]) | |
def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False): | |
""" | |
Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids. | |
Args: | |
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): | |
List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer. | |
skip_special_tokens (`bool`, *optional*, defaults to `False`): | |
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be | |
removed. | |
""" | |
if skip_special_tokens: | |
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>") | |
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>") | |
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id) | |
return token_ids | |
def _filter_timestamp_ids(self, token_ids): | |
return re.sub(self.timestamp_pat, "", token_ids) | |
def decode( | |
self, | |
token_ids, | |
skip_special_tokens: bool = False, | |
clean_up_tokenization_spaces: bool = None, | |
output_offsets: bool = False, | |
time_precision: float = 0.02, | |
decode_with_timestamps: bool = False, | |
normalize: bool = False, | |
basic_normalize: bool = False, | |
remove_diacritics: bool = False, | |
**kwargs, | |
) -> str: | |
""" | |
        Converts a sequence of ids into a string, using the tokenizer and vocabulary, with options to remove special
tokens and clean up tokenization spaces. | |
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. | |
Args: | |
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): | |
List of tokenized input ids. Can be obtained using the `__call__` method. | |
skip_special_tokens (`bool`, *optional*, defaults to `False`): | |
Whether or not to remove special tokens in the decoding. Will remove the previous tokens (pre-prompt) | |
if present. | |
clean_up_tokenization_spaces (`bool`, *optional*): | |
Whether or not to clean up the tokenization spaces. If `None`, will default to | |
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`). | |
output_offsets (`bool`, *optional*, defaults to `False`): | |
Whether or not to output the offsets of the tokens. This should only be set if the model predicted | |
timestamps. If there are previous tokens (pre-prompt) to decode, they will only appear in the decoded | |
text if they contain timestamp tokens. | |
time_precision (`float`, *optional*, defaults to 0.02): | |
The time ratio to convert from token to time. | |
decode_with_timestamps (`bool`, *optional*, defaults to `False`): | |
Whether or not to decode with timestamps included in the raw text. | |
normalize (`bool`, *optional*, defaults to `False`): | |
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the | |
target text is in English. Otherwise, the basic text normalizer should be applied. | |
basic_normalize (`bool`, *optional*, defaults to `False`): | |
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual | |
target text. | |
remove_diacritics (`bool`, *optional*, defaults to `False`): | |
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may | |
destroy information in the decoded text, hence it should be used with caution. | |
kwargs (additional keyword arguments, *optional*): | |
Will be passed to the underlying model specific decode method. | |
Returns: | |
`str`: The decoded sentence. | |
""" | |
filtered_ids = self._preprocess_token_ids( | |
token_ids, | |
skip_special_tokens=skip_special_tokens, | |
) | |
text = super().decode( | |
filtered_ids, | |
skip_special_tokens=skip_special_tokens, | |
clean_up_tokenization_spaces=clean_up_tokenization_spaces, | |
normalize=normalize, | |
basic_normalize=basic_normalize, | |
remove_diacritics=remove_diacritics, | |
**kwargs, | |
) | |
if decode_with_timestamps: | |
# legacy method to decode timestamps when not included in the tokenizer vocabulary | |
text = self._decode_with_timestamps( | |
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens | |
) | |
else: | |
text = self._filter_timestamp_ids(text) | |
# retrieve offsets | |
if output_offsets: | |
offsets = self._compute_offsets(token_ids, time_precision=time_precision) | |
return {"text": text, "offsets": offsets} | |
return text | |
def _decode( | |
self, | |
token_ids: Union[int, List[int]], | |
skip_special_tokens: bool = False, | |
normalize: bool = False, | |
basic_normalize: bool = False, | |
remove_diacritics: bool = False, | |
**kwargs, | |
) -> str: | |
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) | |
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) | |
        # To avoid mixing byte-level and unicode for byte-level BPE
# we need to build string separately for added tokens and byte-level tokens | |
# cf. https://github.com/huggingface/transformers/issues/1133 | |
sub_texts = [] | |
current_sub_text = [] | |
for token in filtered_tokens: | |
if skip_special_tokens and token in self.all_special_ids: | |
continue | |
if token in self.added_tokens_encoder: | |
if current_sub_text: | |
sub_texts.append(self.convert_tokens_to_string(current_sub_text)) | |
current_sub_text = [] | |
sub_texts.append(token) | |
else: | |
current_sub_text.append(token) | |
if current_sub_text: | |
sub_texts.append(self.convert_tokens_to_string(current_sub_text)) | |
text = "".join(sub_texts) | |
if normalize: | |
clean_text = self.normalize(text) | |
return clean_text | |
elif basic_normalize: | |
clean_text = self.basic_normalize(text, remove_diacritics=remove_diacritics) | |
return clean_text | |
else: | |
return text | |
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string with GPT2 -> Whisper | |
def convert_tokens_to_string(self, tokens): | |
"""Converts a sequence of tokens (string) in a single string.""" | |
text = "".join(tokens) | |
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | |
return text | |
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: | |
if not os.path.isdir(save_directory): | |
logger.error(f"Vocabulary path ({save_directory}) should be a directory") | |
return | |
vocab_file = os.path.join( | |
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] | |
) | |
merge_file = os.path.join( | |
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] | |
) | |
normalizer_file = os.path.join( | |
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["normalizer_file"] | |
) | |
with open(vocab_file, "w", encoding="utf-8") as f: | |
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") | |
index = 0 | |
with open(merge_file, "w", encoding="utf-8") as writer: | |
writer.write("#version: 0.2\n") | |
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): | |
if index != token_index: | |
logger.warning( | |
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." | |
" Please check that the tokenizer is not corrupted!" | |
) | |
index = token_index | |
writer.write(" ".join(bpe_tokens) + "\n") | |
index += 1 | |
if self.english_spelling_normalizer is not None: | |
with open(normalizer_file, "w", encoding="utf-8") as f: | |
f.write( | |
json.dumps(self.english_spelling_normalizer, indent=2, sort_keys=True, ensure_ascii=False) + "\n" | |
) | |
return vocab_file, merge_file, normalizer_file | |
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.prepare_for_tokenization with GPT2 -> Whisper | |
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): | |
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) | |
if is_split_into_words or add_prefix_space: | |
text = " " + text | |
return (text, kwargs) | |
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): | |
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps) | |
# prefix tokens are of the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|> | |
# we don't want to force the bos token at position 1, as this is the starting token | |
# when we generate, so we slice the prefix tokens to: <|lang_id|> <|task|> <|notimestamps|> | |
# to get the forced tokens | |
forced_tokens = self.prefix_tokens[1:] | |
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)] | |
return forced_decoder_ids | |
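    # Illustrative note: with language="french", task="transcribe" and no_timestamps=True, the
    # result has the shape [(1, <|fr|> id), (2, <|transcribe|> id), (3, <|notimestamps|> id)],
    # i.e. (generation position, forced token id) pairs usable as `forced_decoder_ids`.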
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision): | |
return _decode_asr( | |
self, | |
model_outputs, | |
return_timestamps=return_timestamps, | |
return_language=return_language, | |
time_precision=time_precision, | |
) | |
def get_prompt_ids(self, text: str, return_tensors="np"): | |
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`].""" | |
batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False) | |
# Check for special tokens | |
prompt_text_ids = batch_encoding["input_ids"][1:] | |
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None) | |
if special_token_id is not None: | |
token = self.convert_ids_to_tokens(special_token_id) | |
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.") | |
batch_encoding.convert_to_tensors(tensor_type=return_tensors) | |
return batch_encoding["input_ids"] | |
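    # Illustrative usage (hypothetical variable names): prompt_ids = tokenizer.get_prompt_ids("Django")
    # can then be passed as model.generate(input_features, prompt_ids=prompt_ids) to bias the
    # transcription towards the prompted spelling; the returned ids include the leading
    # <|startofprev|> token.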
def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int): | |
if not isinstance(token_ids, list): | |
token_ids = self._convert_to_list(token_ids) | |
# handle case of empty token_ids for decoding with timestamps. | |
# at this point token_ids is a list, so it is safe to use if not check. | |
if not token_ids: | |
return token_ids | |
has_prompt = token_ids[0] == prompt_token_id | |
if has_prompt: | |
if decoder_start_token_id in token_ids: | |
return token_ids[token_ids.index(decoder_start_token_id) :] | |
else: | |
return [] | |
return token_ids | |
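    # Illustrative note (symbolic ids): [<|startofprev|>, p1, p2, <|startoftranscript|>, t1, t2]
    # is trimmed to [<|startoftranscript|>, t1, t2]; sequences that do not start with the prompt
    # token are returned unchanged.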
    @staticmethod
    def _convert_to_list(token_ids):
# convert type to ndarray if necessary | |
if hasattr(token_ids, "numpy"): | |
if "torch" in str(type(token_ids)): | |
token_ids = token_ids.cpu().numpy() | |
elif "tensorflow" in str(type(token_ids)): | |
token_ids = token_ids.numpy() | |
elif "jaxlib" in str(type(token_ids)): | |
token_ids = token_ids.tolist() | |
# now the token ids are either a numpy array, or a list of lists | |
if isinstance(token_ids, np.ndarray): | |
token_ids = token_ids.tolist() | |
return token_ids | |
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision): | |
""" | |
    Internal method meant to be used only by the ASR pipeline. Handles all the little quirks specific to whisper needed
    to handle the various options that are not allowed for other seq2seq models.
""" | |
# =========== Overview ============ | |
# - iterate over all outputs | |
# - all tokens within output | |
# - Each token can be | |
# - language token | |
# - special token | |
# - timestamp token | |
# - text token | |
# - We accumulate the text tokens. | |
# - We split on end timestamps | |
# - Lots of complexity comes from stride and timestamps | |
last_language = None | |
def new_chunk(): | |
return {"language": last_language, "timestamp": [None, None], "text": ""} | |
# Welcome to the state machine ! | |
chunks = [] | |
chunk = new_chunk() | |
time_offset = 0.0 | |
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 | |
previous_tokens = [] | |
previous_token_timestamps = [] | |
skip = False | |
right_stride_start = None | |
all_special_ids = set(tokenizer.all_special_ids) | |
prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>") | |
decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>") | |
# - iterate over all outputs | |
for chunk_id, output in enumerate(model_outputs): | |
# We can drop everything to Python list, it's going to make | |
# our lives easier | |
token_ids = output["tokens"][0].tolist() | |
# (possibly) remove the prompt from the token ids | |
token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id) | |
if return_timestamps == "word": | |
token_timestamps = output["token_timestamps"][0].tolist() | |
# Those keep track of timestamps within strides | |
# Which need to be skipped and resolve all tokens in a single | |
# chunk. | |
last_timestamp = None | |
first_timestamp = timestamp_begin | |
if "stride" in output: | |
chunk_len, stride_left, stride_right = output["stride"] | |
# Offset the timings to account for the other `model_outputs`. | |
time_offset -= stride_left | |
right_stride_start = chunk_len - stride_right | |
# Keeping track of timestamps within strides | |
            # We're going to NOT split on those, and delay until we're
            # out of BOTH strides. Otherwise lots of issues and corner
            # cases occur.
if stride_left: | |
first_timestamp = stride_left / time_precision + timestamp_begin | |
if stride_right: | |
for token in reversed(token_ids): | |
if token >= timestamp_begin: | |
                        # There can be several tokens in the right stride
# But the last one is ALWAYS going to be skipped | |
if ( | |
last_timestamp is not None | |
and (token - timestamp_begin) * time_precision < right_stride_start | |
): | |
break | |
last_timestamp = token | |
current_tokens = [] | |
current_token_timestamps = [] | |
# - all tokens within output | |
for i, token in enumerate(token_ids): | |
# 4 possible states for each token | |
# - 1/ Language code | |
# - 2/ all other special tokens (which we ignore) | |
# - 3/ Timestamp | |
# - 4/ Regular text | |
if token in all_special_ids: | |
# Either language code or other | |
text = tokenizer.decode([token]) | |
# Removing outer shell <|XX|> | |
text = text[2:-2] | |
language = LANGUAGES.get(text, None) | |
if language is not None: | |
# 1/ Indeed some language | |
# TODO Handle when language is different from the previous | |
# one, and we cannot use timestamped tokens to create chunks | |
if last_language and language != last_language and not return_timestamps: | |
previous_tokens.append(current_tokens) | |
resolved_tokens = _find_longest_common_sequence(previous_tokens) | |
resolved_text = tokenizer.decode(resolved_tokens) | |
chunk["text"] = resolved_text | |
chunks.append(chunk) | |
# Flush all our temporary context | |
previous_tokens = [] | |
current_tokens = [] | |
chunk = new_chunk() | |
chunk["language"] = language | |
last_language = language | |
else: | |
# 2/ This is a regular special token, ignoring it | |
pass | |
elif token >= timestamp_begin: | |
# 3/ Timestamp token | |
time = (token - timestamp_begin) * time_precision + time_offset | |
time = round(time, 2) | |
if last_timestamp and token >= last_timestamp: | |
# Whisper outputted a timestamp token, but it falls within | |
# our stride, so we're going to skip it for the time being | |
# and resolve this later | |
# Skip is necessary because timestamp tokens always come | |
# by pair, so we need to skip the next one too (which would mark the start of another chunk). | |
skip = True | |
elif skip or (previous_tokens and token < first_timestamp): | |
skip = False | |
elif chunk["timestamp"][0] is None: | |
chunk["timestamp"][0] = time | |
else: | |
# This is the end of the timestamp chunk | |
if time == chunk["timestamp"][0]: | |
# This is a bug in timestamp token output | |
# where we're taking the duplicate token | |
# as a stop where it should be a start. | |
# This is an issue in the underlying model output | |
                        # Let's just skip it so it becomes de facto
                        # a start again
pass | |
else: | |
chunk["timestamp"][1] = time | |
# Handling merges. | |
previous_tokens.append(current_tokens) | |
if return_timestamps == "word": | |
previous_token_timestamps.append(current_token_timestamps) | |
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence( | |
previous_tokens, previous_token_timestamps | |
) | |
resolved_text = tokenizer.decode(resolved_tokens) | |
chunk["text"] = resolved_text | |
if return_timestamps == "word": | |
chunk["words"] = _collate_word_timestamps( | |
tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language | |
) | |
chunks.append(chunk) | |
# Flush all our temporary context | |
previous_tokens = [] | |
current_tokens = [] | |
previous_token_timestamps = [] | |
current_token_timestamps = [] | |
chunk = new_chunk() | |
else: | |
# 4/ Regular token | |
# We just append to the list of all tokens so we can handle | |
# merges later and decode into text. | |
current_tokens.append(token) | |
if return_timestamps == "word": | |
start_time = round(token_timestamps[i] + time_offset, 2) | |
if i + 1 < len(token_timestamps): | |
end_time = round(token_timestamps[i + 1] + time_offset, 2) | |
else: | |
end_time = None # should never happen | |
current_token_timestamps.append((start_time, end_time)) | |
if "stride" in output: | |
time_offset += chunk_len - stride_right | |
# Leftover tokens | |
if current_tokens: | |
previous_tokens.append(current_tokens) | |
if return_timestamps == "word": | |
previous_token_timestamps.append(current_token_timestamps) | |
elif not (any(p for p in previous_tokens)): | |
chunk = new_chunk() | |
previous_tokens = [] | |
current_tokens = [] | |
previous_token_timestamps = [] | |
current_token_timestamps = [] | |
if previous_tokens: | |
if return_timestamps: | |
logger.warning( | |
"Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. " | |
"Also make sure WhisperTimeStampLogitsProcessor was used during generation." | |
) | |
# Happens when we don't use timestamps | |
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence( | |
previous_tokens, previous_token_timestamps | |
) | |
resolved_text = tokenizer.decode(resolved_tokens) | |
chunk["text"] = resolved_text | |
if return_timestamps == "word": | |
chunk["words"] = _collate_word_timestamps( | |
tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language | |
) | |
chunks.append(chunk) | |
# Preparing and cleaning up the pipeline output | |
full_text = "".join(chunk["text"] for chunk in chunks) | |
if return_timestamps or return_language: | |
for chunk in chunks: | |
if not return_timestamps: | |
chunk.pop("timestamp") | |
else: | |
chunk["timestamp"] = tuple(chunk["timestamp"]) | |
if not return_language: | |
chunk.pop("language") | |
if return_timestamps == "word": | |
new_chunks = [] | |
for chunk in chunks: | |
new_chunks.extend(chunk["words"]) | |
optional = {"chunks": new_chunks} | |
else: | |
optional = {"chunks": chunks} | |
else: | |
optional = {} | |
return full_text, optional | |
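# Illustrative note (hypothetical values): with return_timestamps=True this returns something
# like ("full transcript", {"chunks": [{"text": " Hello.", "timestamp": (0.0, 1.5)}]}); with
# neither timestamps nor language requested it is simply ("full transcript", {}).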
def _find_longest_common_sequence(sequences, token_timestamp_sequences=None): | |
# It would be much harder to do O(n) because of fault tolerance. | |
# We actually have a really good property which is that the total sequence | |
# MUST be those subsequences in order. | |
# If token_timestamp_sequences is provided, will split those sequences in | |
# exactly the same way. | |
left_sequence = sequences[0] | |
left_length = len(left_sequence) | |
total_sequence = [] | |
if token_timestamp_sequences: | |
left_token_timestamp_sequence = token_timestamp_sequences[0] | |
total_token_timestamp_sequence = [] | |
for seq_idx, right_sequence in enumerate(sequences[1:]): | |
# index = 0 | |
max_ = 0.0 | |
max_indices = (left_length, left_length, 0, 0) | |
# Here we're sliding matches | |
# [a, b, c, d] | |
# [c, d, f] | |
# = [c] == [d] | |
# | |
# [a, b, c, d] | |
# [c, d, f] | |
# = [c, d] == [c, d] | |
# | |
# | |
# [a, b, c, d] | |
# [c, d, f] | |
# | |
# = [b, c, d] == [c, d, f] | |
# | |
# [a, b, c, d] | |
# [c, d, f] | |
# | |
# [a, b, c] == [c, d, f] | |
# | |
# [a, b, c, d] | |
# [d, f] | |
# | |
# [a, b] == [d, f] | |
# | |
# [a, b, c, d] | |
# [f] | |
# | |
# [a] == [f] | |
right_length = len(right_sequence) | |
for i in range(1, left_length + right_length): | |
# epsilon to favor long perfect matches | |
eps = i / 10000.0 | |
# Slightly convoluted because we don't want out of bound indices | |
# This will be necessary for a small conflict resolution optimization | |
# later | |
left_start = max(0, left_length - i) | |
left_stop = min(left_length, left_length + right_length - i) | |
left = np.array(left_sequence[left_start:left_stop]) | |
right_start = max(0, i - left_length) | |
right_stop = min(right_length, i) | |
right = np.array(right_sequence[right_start:right_stop]) | |
# We can only match subsequences of the same size. | |
if len(left) != len(right): | |
raise RuntimeError( | |
"There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference." | |
) | |
if token_timestamp_sequences: | |
# Get length of longest subsequence of tokens that match | |
# and have timestamps that are in order | |
matches = sum( | |
1 | |
for idx, elem in enumerate(left) | |
if ( | |
elem == right[idx] | |
and left_token_timestamp_sequence[left_start + idx] | |
<= token_timestamp_sequences[seq_idx + 1][right_start + idx] | |
) | |
) | |
else: | |
matches = np.sum(left == right) | |
matching = matches / i + eps | |
if matches > 1 and matching > max_: | |
max_ = matching | |
max_indices = (left_start, left_stop, right_start, right_stop) | |
(left_start, left_stop, right_start, right_stop) = max_indices | |
# This is a small conflict optimization since those sequences overlap | |
# in audio. | |
# We're going to give more confidence to the left sequence | |
# for the left of the overlap, | |
# and to the right of the sequence, for the right of the overlap | |
left_mid = (left_stop + left_start) // 2 | |
right_mid = (right_stop + right_start) // 2 | |
total_sequence.extend(left_sequence[:left_mid]) | |
left_sequence = right_sequence[right_mid:] | |
left_length = len(left_sequence) | |
if token_timestamp_sequences: | |
total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid]) | |
left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:] | |
total_sequence.extend(left_sequence) | |
if token_timestamp_sequences is None: | |
return total_sequence | |
if len(token_timestamp_sequences) > 0: | |
total_token_timestamp_sequence.extend(left_token_timestamp_sequence) | |
return total_sequence, total_token_timestamp_sequence | |
else: | |
return total_sequence, [] | |
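# Illustrative note: _find_longest_common_sequence([[1, 2, 3, 4], [3, 4, 5]]) slides the two
# chunks against each other, scores the [3, 4] overlap highest, keeps the left half of the
# overlap from the first chunk and the right half from the second, and returns [1, 2, 3, 4, 5].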
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language): | |
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language) | |
optional_language_field = {"language": language} if return_language else {} | |
timings = [ | |
{ | |
"text": word, | |
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]), | |
**optional_language_field, | |
} | |
for word, indices in zip(words, token_indices) | |
] | |
return timings | |
def _combine_tokens_into_words( | |
tokenizer, | |
tokens: List[int], | |
language: str = None, | |
prepend_punctuations: str = "\"'“¡¿([{-", | |
append_punctuations: str = "\"'.。,,!!??::”)]}、", | |
): | |
""" | |
Groups tokens by word. Returns a tuple containing a list of strings with the words, and a list of `token_id` | |
sequences with the tokens making up each word. | |
""" | |
if language is None: | |
language = tokenizer.language | |
if language is None: | |
language = "english" | |
if language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}: | |
# These languages don't typically use spaces. | |
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens) | |
else: | |
words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens) | |
_merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations) | |
return words, word_tokens, token_indices | |
def _split_tokens_on_unicode(tokenizer, tokens: List[int]): | |
"""Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points.""" | |
decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True) | |
replacement_char = "\ufffd" | |
words = [] | |
word_tokens = [] | |
token_indices = [] | |
current_tokens = [] | |
current_indices = [] | |
unicode_offset = 0 | |
for token_idx, token in enumerate(tokens): | |
current_tokens.append(token) | |
current_indices.append(token_idx) | |
decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True) | |
if ( | |
replacement_char not in decoded | |
or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char | |
): | |
words.append(decoded) | |
word_tokens.append(current_tokens) | |
token_indices.append(current_indices) | |
current_tokens = [] | |
current_indices = [] | |
unicode_offset += len(decoded) | |
return words, word_tokens, token_indices | |
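# Illustrative note: tokens are accumulated until they decode without a dangling "\ufffd"
# replacement character, so a multi-byte character split across several BPE tokens is emitted
# as a single "word" rather than as broken fragments.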
def _split_tokens_on_spaces(tokenizer, tokens: List[int]): | |
"""Combine tokens into words by splitting at whitespace and punctuation tokens.""" | |
subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens) | |
words = [] | |
word_tokens = [] | |
token_indices = [] | |
for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list): | |
special = subword_tokens[0] >= tokenizer.eos_token_id | |
with_space = subword.startswith(" ") | |
punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" | |
if special or with_space or punctuation or len(words) == 0: | |
words.append(subword) | |
word_tokens.append(subword_tokens) | |
token_indices.append(subword_indices) | |
else: | |
words[-1] = words[-1] + subword | |
word_tokens[-1].extend(subword_tokens) | |
token_indices[-1].extend(subword_indices) | |
return words, word_tokens, token_indices | |
def _merge_punctuations(words, tokens, indices, prepended, appended): | |
"""Merges punctuation tokens with neighboring words.""" | |
# prepend punctuations | |
i = len(words) - 2 | |
j = len(words) - 1 | |
while i >= 0: | |
if words[i].startswith(" ") and words[i].strip() in prepended: | |
words[j] = words[i] + words[j] | |
tokens[j] = tokens[i] + tokens[j] | |
indices[j] = indices[i] + indices[j] | |
words[i] = "" | |
tokens[i] = [] | |
indices[i] = [] | |
else: | |
j = i | |
i -= 1 | |
# append punctuations | |
i = 0 | |
j = 1 | |
while j < len(words): | |
if not words[i].endswith(" ") and words[j] in appended: | |
words[i] += words[j] | |
tokens[i] += tokens[j] | |
indices[i] += indices[j] | |
words[j] = "" | |
tokens[j] = [] | |
indices[j] = [] | |
else: | |
i = j | |
j += 1 | |
# remove elements that are now empty | |
words[:] = [word for word in words if word] | |
tokens[:] = [token for token in tokens if token] | |
indices[:] = [idx for idx in indices if idx] | |
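# Illustrative note: given words == [" Hello", ",", " world"], the appended-punctuation pass
# folds the comma into the preceding word, leaving [" Hello,", " world"]; the parallel token and
# index lists are merged the same way so word-level timestamps stay aligned.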