File size: 2,250 Bytes
abca9bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from functools import partial
import numpy as np


def decode(id_to_something, tokenizer=None, data_args=None):
    """Decode a mapping of ids to text or token ids into a mapping of ids to stripped text.

    The value type is sniffed from the first value in the dict:
      * ``str``                      -> stripped as-is (no tokenizer needed)
      * ``list[str]``                -> each element stripped (no tokenizer needed)
      * ``list`` of token-id arrays  -> ``tokenizer.batch_decode``, each result stripped
      * anything else (token ids)    -> ``tokenizer.decode``, result stripped

    Args:
        id_to_something: dict mapping an id to one of the value kinds above.
            All values are assumed to be of the same kind as the first one.
        tokenizer: tokenizer exposing ``decode`` / ``batch_decode`` and
            ``pad_token_id`` (HuggingFace-style); required for the token-id cases.
        data_args: dict or namespace; its ``ignore_pad_token_for_loss`` flag
            (default True) controls whether ``-100`` label placeholders are
            replaced with the pad token id before decoding.

    Returns:
        dict mapping each id to decoded, stripped text (``str`` or ``list[str]``).
    """
    if not id_to_something:
        # Original crashed with StopIteration here; an empty input has an
        # obvious empty answer.
        return {}

    elem = next(iter(id_to_something.values()))
    if isinstance(elem, str):
        switch_case = -1
        decode_fn = lambda text: text.strip()
    elif isinstance(elem, list) and elem and not isinstance(elem[0], int):
        # The extra `elem` truthiness check avoids an IndexError on an empty
        # list; an empty list falls through to the single-sequence case below.
        if isinstance(elem[0], str):
            switch_case = 0
            decode_fn = lambda texts: [text.strip() for text in texts]
        else:
            switch_case = 1
            decode_fn = lambda token_ids_list: [
                text.strip()
                for text in tokenizer.batch_decode(
                    token_ids_list, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            ]
    else:
        switch_case = 2
        decode_fn = lambda token_ids: tokenizer.decode(
            token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).strip()

    # Normalize data_args once, outside the loop (it is loop-invariant).
    # Only the token-id cases consult it, matching the original behavior of
    # never touching data_args for plain-text inputs.
    replace_ignore_index = False
    if switch_case in (1, 2):
        if data_args is None:
            data_args = {}
        if not isinstance(data_args, dict):
            data_args = vars(data_args)
        replace_ignore_index = data_args.get("ignore_pad_token_for_loss", True)

    id_to_text = {}
    for id_, something in id_to_something.items():
        if not replace_ignore_index:
            obj_to_decode = something
        elif switch_case == 1:
            # Replace -100 in the token ids as we can't decode them. Build a
            # new list rather than mutating the caller's data in place.
            obj_to_decode = [
                _replace_padding(token_ids, tokenizer.pad_token_id) for token_ids in something
            ]
        else:  # switch_case == 2
            obj_to_decode = _replace_padding(something, tokenizer.pad_token_id)

        id_to_text[id_] = decode_fn(obj_to_decode)

    return id_to_text


def _replace_padding(token_ids: np.array, pad_token_id):
    return np.where(token_ids != -100, token_ids, pad_token_id)