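"""Model loading and generator management utilities for a Streamlit text-generation demo."""
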
import os
from typing import List

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# Use the highest-numbered CUDA device when one is available; -1 means run on CPU.
device = torch.cuda.device_count() - 1

# Hugging Face task name for the English-to-Dutch translation generators.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name, task):
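    """Load and return the tokenizer and model for ``model_name`` and ``task``.

    Streamlit caches the result, so repeated calls with the same arguments
    do not reload the weights.
    """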
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Prefer a token from Streamlit secrets; fall back to the environment.
    if os.path.exists(".streamlit/secrets.toml"):
        access_token = st.secrets.get("babel")
    else:
        access_token = os.environ.get("HF_ACCESS_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model


class Generator:
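    """Wraps a tokenizer/model pair for one task, with an optional task prefix."""
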
    def __init__(self, model_name, task, desc):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.load()

    def load(self):
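        """Load the tokenizer and model once, and pick up any task-specific prefix."""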
        if not self.model:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name, self.task)

            # task_specific_params may be None, so guard before the membership test.
            task_specific_params = self.model.config.task_specific_params
            if task_specific_params and self.task in task_specific_params:
                self.prefix = task_specific_params[self.task].get("prefix", "")

    def generate(self, text: str, **generate_kwargs) -> List[str]:
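        """Generate continuations for ``text`` and return the decoded strings."""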
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        # model.generate returns token ids, not logits.
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Special tokens are decoded on purpose so the pad/eos markers can be
        # stripped explicitly below.
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds

    def __str__(self):
        return self.desc


class GeneratorFactory:
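    """Builds the configured Generator instances and provides lookup over them."""
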
    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc):
        # If the generator is not yet present, add it. Generator.__init__
        # already calls load(), so no separate load step is needed here.
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            self.generators.append(Generator(model_name, task, desc))

    def get_generator(self, **kwargs):
        # Return the first generator whose attributes match all of kwargs, else None.
        for g in self.generators:
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def gpt_descs(self):
        # Descriptions of the causal-LM (GPT-style) generators, i.e. those that
        # are not registered under the translation task.
        return [g.desc for g in self.generators if g.task != TRANSLATION_EN_TO_NL]
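

# Minimal usage sketch (kept as a comment; the model name and description are
# hypothetical, and the real GENERATOR_LIST lives in the app that imports this
# module):
#
#     GENERATOR_LIST = [
#         {
#             "model_name": "some-org/t5-base-en-nl",  # hypothetical model id
#             "task": TRANSLATION_EN_TO_NL,
#             "desc": "T5 base EN->NL translator",
#         },
#     ]
#     factory = GeneratorFactory(GENERATOR_LIST)
#     translator = factory.get_generator(task=TRANSLATION_EN_TO_NL)
#     print(translator.generate("Hello world!", max_length=128, num_beams=4))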