import _thread  # noqa: F401 (referenced by the commented-out st.cache decorator below)
import os
import re

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Index of the last visible GPU; -1 (no GPU available) means run on the CPU.
device = torch.cuda.device_count() - 1


def get_access_token():
    """Return the Hugging Face access token from Streamlit secrets, falling
    back to the HF_ACCESS_TOKEN environment variable."""
    try:
        if not os.path.exists(".streamlit/secrets.toml"):
            raise FileNotFoundError
        access_token = st.secrets.get("babel")
    except FileNotFoundError:
        access_token = os.environ.get("HF_ACCESS_TOKEN", None)
    return access_token
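
# Expected .streamlit/secrets.toml layout (an assumption inferred from the
# st.secrets.get("babel") lookup above; the token value is a placeholder):
#
#     babel = "hf_..."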


# @st.cache(hash_funcs={_thread.RLock: lambda _: None}, suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and model, trying PyTorch, Flax and TF weights in turn."""
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=("ul2" not in model_name),  # UL2 checkpoints use the slow tokenizer
        use_auth_token=get_access_token(),
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    # Try native PyTorch weights first, then fall back to Flax and TF checkpoints.
    for framework in [None, "flax", "tf"]:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                from_flax=(framework == "flax"),
                from_tf=(framework == "tf"),
                use_auth_token=get_access_token(),
            )
            break
        except EnvironmentError:
            if framework == "tf":
                # The last fallback also failed; re-raise the error.
                raise
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model


class Generator:
    def __init__(self, model_name, task, desc, split_sentences):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        # Default generation settings; load() overrides these from the model config.
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        if not self.model:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name)
            # Let values from the model config override the generation defaults.
            for key in self.gen_kwargs:
                if key in self.model.config.__dict__:
                    self.gen_kwargs[key] = self.model.config.__dict__[key]
            # Apply task-specific parameters (e.g. a task prefix) when the
            # model config defines any for this task.
            try:
                if self.task in self.model.config.task_specific_params:
                    task_specific_params = self.model.config.task_specific_params[
                        self.task
                    ]
                    if "prefix" in task_specific_params:
                        self.prefix = task_specific_params["prefix"]
                    for key in self.gen_kwargs:
                        if key in task_specific_params:
                            self.gen_kwargs[key] = task_specific_params[key]
            except TypeError:
                # task_specific_params is None for models without task settings.
                pass

    def generate(
        self, text: str, streamer=None, **generate_kwargs
    ) -> tuple[str, dict]:
        # Replace two or more newlines with a single newline in text
        text = re.sub(r"\n{2,}", "\n", text)
        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
        # If there are newlines in the text and the model needs line-splitting,
        # translate each line separately and recurse.
        if re.search(r"\n", text) and self.split_sentences:
            lines = text.splitlines()
            translated = [
                self.generate(line, streamer, **generate_kwargs)[0] for line in lines
            ]
            return "\n".join(translated), generate_kwargs
        # If self.tokenizer has a newline_token attribute, replace \n with it
        if hasattr(self.tokenizer, "newline_token"):
            text = re.sub(r"\n", self.tokenizer.newline_token, text)
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded.to(f"cuda:{device}")
        # model.generate returns sequences of token ids.
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            streamer=streamer,
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )

        def replace_tokens(pred):
            # Strip padding and end-of-sequence markers; restore newlines.
            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            if hasattr(self.tokenizer, "newline_token"):
                pred = pred.replace(self.tokenizer.newline_token, "\n")
            return pred

        decoded_preds = list(map(replace_tokens, decoded_preds))
        return decoded_preds[0], generate_kwargs

    def __str__(self):
        return self.model_name


class GeneratorFactory:
    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        # If the generator is not yet present, add it
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            g = Generator(model_name, task, desc, split_sentences)
            g.load()
            self.generators.append(g)

    def get_generator(self, **kwargs):
        # Return the first generator whose attributes match all keyword filters.
        for g in self.generators:
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        # Return every generator whose attributes match all keyword filters.
        return [
            g
            for g in self.generators
            if all(g.__dict__.get(k) == v for k, v in kwargs.items())
        ]
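
# Minimal usage sketch (for illustration only: the model name, task and
# description below are placeholders, not values used by this module):
#
#     factory = GeneratorFactory(
#         [
#             {
#                 "model_name": "Helsinki-NLP/opus-mt-en-nl",
#                 "task": "translation",
#                 "desc": "English to Dutch",
#                 "split_sentences": True,
#             }
#         ]
#     )
#     translator = factory.get_generator(task="translation")
#     translated, used_kwargs = translator.generate("Hello world!\nHow are you?")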