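"""Model-loading and generation helpers for a Streamlit seq2seq demo.

Hugging Face seq2seq checkpoints (PyTorch, Flax, or TensorFlow weights) are
loaded and wrapped in Generator objects that apply per-task generation
settings.
"""
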
import _thread  # referenced by the commented-out st.cache decorator below
import os
import re

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Index of the last GPU, or -1 when CUDA is unavailable (then run on CPU).
device = torch.cuda.device_count() - 1


def get_access_token():
    """Return the Hugging Face access token from Streamlit secrets if present,
    falling back to the HF_ACCESS_TOKEN environment variable."""
    if os.path.exists(".streamlit/secrets.toml"):
        return st.secrets.get("babel")
    return os.environ.get("HF_ACCESS_TOKEN")


# @st.cache(hash_funcs={_thread.RLock: lambda _: None}, suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and seq2seq model for `model_name`."""
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=("ul2" not in model_name),
        use_auth_token=get_access_token(),
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
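    # Try native PyTorch weights first, then fall back to Flax, then TF;
    # only re-raise when the final framework also fails to load.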
    for framework in [None, "flax", "tf"]:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                from_flax=(framework == "flax"),
                from_tf=(framework == "tf"),
                use_auth_token=get_access_token(),
            )
            break
        except EnvironmentError:
            if framework == "tf":
                raise
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
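

# A quick smoke test could look like the following ("t5-small" is just an
# illustrative public checkpoint, not necessarily one this app uses):
#
#     tokenizer, model = load_model("t5-small")
#     print(model.config.task_specific_params)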


class Generator:
    """Wrap a seq2seq model and tokenizer with task-specific generation settings."""

    def __init__(self, model_name, task, desc, split_sentences):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        if not self.model:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name)

            # Let values from the model config override the generation defaults.
            for key in self.gen_kwargs:
                if key in self.model.config.__dict__:
                    self.gen_kwargs[key] = self.model.config.__dict__[key]
            try:
                if self.task in self.model.config.task_specific_params:
                    task_specific_params = self.model.config.task_specific_params[
                        self.task
                    ]
                    if "prefix" in task_specific_params:
                        self.prefix = task_specific_params["prefix"]
                    for key in self.gen_kwargs:
                        if key in task_specific_params:
                            self.gen_kwargs[key] = task_specific_params[key]
            except TypeError:
                # model.config.task_specific_params is None for this model
                pass

    def generate(self, text: str, streamer=None, **generate_kwargs) -> tuple[str, dict]:
        """Generate a prediction for `text`; return it with the kwargs that were used."""
        # Replace two or more newlines with a single newline in text
        text = re.sub(r"\n{2,}", "\n", text)

        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}

        # If the text contains newlines and the model needs sentence-splitting,
        # generate line by line and recurse.
        if re.search(r"\n", text) and self.split_sentences:
            lines = text.splitlines()
            translated = [
                self.generate(line, streamer, **generate_kwargs)[0] for line in lines
            ]
            return "\n".join(translated), generate_kwargs

        # If the tokenizer defines a newline_token, substitute it for "\n"
        if hasattr(self.tokenizer, "newline_token"):
            text = re.sub(r"\n", self.tokenizer.newline_token, text)

        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            streamer=streamer,
            **generate_kwargs,
        )
        # Keep special tokens while decoding so a custom newline token survives;
        # replace_tokens() below strips <pad>/</s> and restores real newlines.
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )

        def replace_tokens(pred):
            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            if hasattr(self.tokenizer, "newline_token"):
                pred = pred.replace(self.tokenizer.newline_token, "\n")
            return pred

        decoded_preds = list(map(replace_tokens, decoded_preds))
        return decoded_preds[0], generate_kwargs

    def __str__(self):
        return self.model_name


class GeneratorFactory:
    """Load and index Generator instances built from a list of model configs."""

    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        # If the generator is not yet present, add it
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            g = Generator(model_name, task, desc, split_sentences)
            self.generators.append(g)

    def get_generator(self, **kwargs):
        for g in self.generators:
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        return [
            g
            for g in self.generators
            if all(g.__dict__.get(k) == v for k, v in kwargs.items())
        ]
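

# Minimal usage sketch (illustrative only; the model name and task below are
# assumptions, not this app's actual configuration). Under `streamlit run`
# the spinners render in the page; in a bare interpreter they degrade
# gracefully.
if __name__ == "__main__":
    factory = GeneratorFactory(
        [
            {
                "model_name": "t5-small",
                "task": "translation_en_to_de",
                "desc": "T5 small (EN->DE)",
                "split_sentences": True,
            }
        ]
    )
    translator = factory.get_generator(task="translation_en_to_de")
    result, used_kwargs = translator.generate("The weather is nice today.")
    print(result)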