from functools import partial

import numpy as np
def decode(id_to_something, tokenizer=None, data_args=None):
    """Map each id to stripped, decoded text.

    Values may be strings, lists of strings, lists of token-id sequences,
    or a single token-id sequence; the last two require a tokenizer.
    """
    decode_fn = None
    switch_case = None
    # Peek at one value to decide how every entry should be decoded.
    elem = next(iter(id_to_something.values()))
    if isinstance(elem, str):
        # Case -1: values are already plain strings.
        switch_case = -1
        decode_fn = lambda text: text.strip()
    elif isinstance(elem, list) and not isinstance(elem[0], int):
        if isinstance(elem[0], str):
            # Case 0: values are lists of strings.
            switch_case = 0
            decode_fn = lambda texts: [text.strip() for text in texts]
        else:
            # Case 1: values are lists of token-id sequences.
            switch_case = 1
            decode_fn = lambda token_ids_list: [
                text.strip()
                for text in tokenizer.batch_decode(
                    token_ids_list, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            ]
    else:
        # Case 2: each value is a single token-id sequence (e.g. a list of
        # ints or a numpy array).
        switch_case = 2
        decode_fn = lambda token_ids: tokenizer.decode(
            token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).strip()

    id_to_text = {}
    for id_, something in id_to_something.items():
        if switch_case == -1 or switch_case == 0:
            obj_to_decode = something
        else:
            if data_args is None:
                data_args = {}
            if not isinstance(data_args, dict):
                data_args = vars(data_args)
            if data_args.get("ignore_pad_token_for_loss", True):
                # Replace -100 in the token ids, since the tokenizer cannot
                # decode the loss-masking value.
                if switch_case == 1:
                    token_ids_list = something
                    for i in range(len(token_ids_list)):
                        token_ids_list[i] = _replace_padding(token_ids_list[i], tokenizer.pad_token_id)
                    obj_to_decode = token_ids_list
                elif switch_case == 2:
                    token_ids = something
                    token_ids = _replace_padding(token_ids, tokenizer.pad_token_id)
                    obj_to_decode = token_ids
            else:
                obj_to_decode = something
        id_to_text[id_] = decode_fn(obj_to_decode)
    return id_to_text


def _replace_padding(token_ids: np.ndarray, pad_token_id):
    # -100 marks positions ignored by the loss; swap it for the pad token id
    # so the sequence becomes decodable.
    return np.where(token_ids != -100, token_ids, pad_token_id)
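

if __name__ == "__main__":
    # Usage sketch: the string and list-of-strings cases run without a
    # tokenizer, so they can be exercised directly.
    print(decode({1: "  hello  ", 2: "world \n"}))  # -> {1: 'hello', 2: 'world'}
    print(decode({1: ["  a ", " b "]}))             # -> {1: ['a', 'b']}

    # Token-id inputs assume a Hugging Face tokenizer with a defined
    # pad_token_id, e.g. (requires the `transformers` package):
    #     from transformers import AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained("t5-small")
    #     ids = np.array(tok("hello world").input_ids)
    #     decode({1: ids}, tokenizer=tok)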