import os
from typing import List

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# Use the last available GPU; -1 means no GPU is present and inference stays on CPU.
device = torch.cuda.device_count() - 1

# Task identifier for the English-to-Dutch translation models.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name, task):
    """Load a tokenizer and (Flax-converted) model from the Hugging Face Hub.

    The Hub access token is read from Streamlit secrets when a secrets file
    exists, and from the HF_ACCESS_TOKEN environment variable otherwise.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    try:
        if not os.path.exists(".streamlit/secrets.toml"):
            raise FileNotFoundError
        access_token = st.secrets.get("babel")
    except FileNotFoundError:
        access_token = os.environ.get("HF_ACCESS_TOKEN", None)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        # GPT-style tokenizers have no pad token; reuse the EOS token for padding.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    # Translation tasks need an encoder-decoder model; everything else is a causal LM.
    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
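
# Because of st.cache, repeated calls with the same arguments return the cached
# (tokenizer, model) pair instead of reloading from the Hub. A hypothetical
# call (the model name is a placeholder, not one the app actually uses):
#
#     tokenizer, model = load_model("some-org/t5-en-nl", TRANSLATION_EN_TO_NL)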


class Generator:
    def __init__(self, model_name, task, desc):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.load()

    def load(self):
        if not self.model:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name, self.task)
            # T5-style configs can define a per-task prefix (e.g. "translate
            # English to Dutch: ") that must be prepended to every input text.
            try:
                if self.task in self.model.config.task_specific_params:
                    task_specific_params = self.model.config.task_specific_params[
                        self.task
                    ]
                    if "prefix" in task_specific_params:
                        self.prefix = task_specific_params["prefix"]
            except TypeError:
                # task_specific_params is None for models without per-task settings.
                pass

    def generate(self, text: str, **generate_kwargs) -> List[str]:
        # The caller-supplied max_length is reused to bound the input encoding.
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        # model.generate returns token ids, not logits.
        generated_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            generated_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Strip pad and end-of-sequence markers by hand so that any other
        # special tokens in the output are preserved.
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds

    def __str__(self):
        return self.desc
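
# A minimal sketch of driving a single Generator directly, assuming a
# hypothetical model name; generate() requires max_length in its kwargs and
# returns a list of decoded strings:
#
#     gen = Generator("some-org/t5-en-nl", TRANSLATION_EN_TO_NL, "demo translator")
#     outputs = gen.generate("How are you?", max_length=128, num_beams=4)
#     print(outputs[0])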


class GeneratorFactory:
    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc):
        # If the generator is not yet present, add it
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            g = Generator(model_name, task, desc)
            g.load()
            self.generators.append(g)

    def get_generator(self, **kwargs):
        # Return the first generator whose attributes match all keyword arguments.
        for g in self.generators:
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def translation_descs(self):
        # Descriptions of the generators that run the translation task.
        return [g.desc for g in self.generators if g.task == TRANSLATION_EN_TO_NL]
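
# A minimal usage sketch (assumed, not part of the app itself): the factory
# takes a list of dicts whose keys match add_generator()'s signature. The
# model names below are hypothetical placeholders.
#
#     GENERATOR_LIST = [
#         {
#             "model_name": "some-org/gpt2-dutch",
#             "task": "text-generation",
#             "desc": "Dutch GPT-2",
#         },
#         {
#             "model_name": "some-org/t5-en-nl",
#             "task": TRANSLATION_EN_TO_NL,
#             "desc": "English-to-Dutch translator",
#         },
#     ]
#     factory = GeneratorFactory(GENERATOR_LIST)
#     for generator in factory:
#         st.write(str(generator))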