|
import json |
|
import os |
|
import time |
|
from random import randint |
|
|
|
import psutil |
|
import streamlit as st |
|
import torch |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoModelForSeq2SeqLM, |
|
AutoTokenizer, |
|
pipeline, |
|
set_seed, |
|
) |
|
|
|
from generator import GeneratorFactory |
|
|
|
# Highest CUDA device index torch reports, or -1 when it reports no CUDA
# devices; main() compares this value against -1 when estimating run time.
device = -1 + torch.cuda.device_count()
|
|
|
# Task label shared by every entry in GENERATOR_LIST: English -> Dutch.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"

# Backward-compatible alias. The original name states the opposite direction
# of the value it holds; prefer TRANSLATION_EN_TO_NL in new code.
TRANSLATION_NL_TO_EN = TRANSLATION_EN_TO_NL

# Model configurations handed to GeneratorFactory in main(). Each entry has:
#   model_name - model identifier string
#   desc       - short human-readable description
#   task       - task label for the model
GENERATOR_LIST = [
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
        "desc": "longT5 large nl8 512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
]
|
|
|
|
|
# Default English passage shown in the input box on first load. The paragraphs
# are joined with blank lines so source indentation never leaks into the text.
# The curly quotes are written as escapes (\u201c / \u201d) so they survive
# any re-encoding of this file.
_DEFAULT_PROMPT = "\n\n".join(
    [
        "It was a wet, gusty night and I had a lonely walk home. By taking the "
        "river road, though I hated it, I saved two miles, so I sloshed ahead "
        "trying not to think at all. Through the barbed wire fence I could see "
        "the racing river. Its black swollen body writhed along with "
        "extraordinary swiftness, breathlessly silent, only occasionally making "
        "a swishing ripple. I did not enjoy looking at it. I was somehow afraid.",
        "And there, at the end of the river road where I swerved off, a figure "
        "stood waiting for me, motionless and enigmatic. I had to meet it or "
        "turn back.",
        "It was a quite young girl, unknown to me, with a hood over her head, "
        "and with large unhappy eyes.",
        "\u201cMy father is very ill,\u201d she said without a word of "
        "introduction. \u201cThe nurse is frightened. Could you come in and "
        "help?\u201d",
    ]
)


def _estimate_seconds(max_length, num_return_sequences, num_beams=None, on_cpu=True):
    """Return a rough whole-second estimate of one generation run.

    Args:
        max_length: Requested maximum output length.
        num_return_sequences: Number of sequences requested.
        num_beams: Beam count when beam search is selected, otherwise None.
        on_cpu: True to use the CPU scaling factors, False for the GPU ones.

    Returns:
        The estimate truncated to an int.
    """
    estimate = max_length / 18
    extra_sequences = num_return_sequences - 1
    if on_cpu:
        estimate = estimate * (1 + 0.7 * extra_sequences)
        if num_beams is not None:
            estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
    else:
        estimate = estimate * (1 + 0.1 * extra_sequences)
        estimate = 0.5 + estimate / 5
        if num_beams is not None:
            estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
    return int(estimate)


def main():
    """Render the Babel Streamlit page and run the generators on request.

    The page consists of a sidebar with the generation parameters, a text
    area holding the text to process, a button that draws a new random seed,
    and a "Run" button. Pressing "Run" feeds the text to every generator held
    in the session state, writes each result with its timing, and finally
    shows memory usage and the parameters of the run as JSON.
    """
    st.set_page_config(
        page_title="Babel",
        layout="wide",
        initial_sidebar_state="expanded",
        # NOTE(review): "π" looks like the remnant of a mis-decoded emoji;
        # confirm the intended icon.
        page_icon="π",
    )

    # Build the generators once per session; later reruns reuse them.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    # Explicit encoding so the stylesheet reads the same on every platform.
    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown("# Babel\nVertaal van en naar Engels")
    # NOTE(review): the selected description is not used anywhere below; every
    # generator runs when "Run" is pressed. Confirm whether that is intended.
    model_desc = st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
    st.sidebar.title("Parameters:")

    if "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = _DEFAULT_PROMPT
    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=300
    )

    max_length = st.sidebar.number_input(
        "Lengte van de tekst",
        value=200,
        max_value=4096,
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No-repeat NGram size", min_value=1, max_value=5, value=3
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
    )
    num_return_sequences = st.sidebar.number_input(
        "Num return sequences", min_value=1, max_value=5, value=1
    )

    # The seed widget lives in a placeholder so it can be redrawn in place
    # when a new random seed is requested.
    seed_placeholder = st.sidebar.empty()
    if "seed" not in st.session_state:
        print("Session state does not contain seed")
        st.session_state["seed"] = 4162549114
        print(f"Seed is set to: {st.session_state['seed']}")

    seed = seed_placeholder.number_input(
        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
    )

    if st.button("Set new random seed"):
        st.session_state["seed"] = randint(0, 2**32 - 1)
        # Rebind `seed` itself so the rest of this run sees the new value.
        seed = seed_placeholder.number_input(
            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
        )
        print(f"New random seed set to: {seed}")

    sampling_mode = st.sidebar.selectbox(
        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
    )

    # Settings shared by both modes; the mode-specific ones are added below,
    # so `params` is always defined.
    params = {
        "max_length": max_length,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": num_return_sequences,
    }
    num_beams = None  # stays None unless beam search is selected
    if sampling_mode == "Beam Search":
        num_beams = st.sidebar.number_input(
            "Num beams", min_value=1, max_value=10, value=4
        )
        length_penalty = st.sidebar.number_input(
            "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
        )
        params.update(
            num_beams=num_beams,
            early_stopping=True,
            length_penalty=length_penalty,
        )
    else:
        top_k = st.sidebar.number_input(
            "Top K", min_value=0, max_value=100, value=50
        )
        top_p = st.sidebar.number_input(
            "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
        )
        temperature = st.sidebar.number_input(
            "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
        )
        params.update(
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

    st.sidebar.markdown(
        "For an explanation of the parameters, head over to the "
        "[Huggingface blog post about text generation]"
        "(https://huggingface.co/blog/how-to-generate)\n"
        "and the [Huggingface text generation interface doc]"
        "(https://huggingface.co/transformers/main_classes/model.html"
        "?highlight=generate#transformers.generation_utils.GenerationMixin.generate).\n"
    )

    if st.button("Run"):
        estimate = _estimate_seconds(
            max_length,
            num_return_sequences,
            num_beams=num_beams,
            on_cpu=device == -1,
        )

        with st.spinner(
            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
        ):
            memory = psutil.virtual_memory()
            model_names = []

            for generator in generators:
                st.subheader(f"Result from {generator}")
                # Re-seed before every generator so each starts from the
                # same random state.
                set_seed(seed)
                time_start = time.perf_counter()
                result = generator.generate(text=st.session_state.text, **params)
                time_diff = time.perf_counter() - time_start

                for text in result:
                    # Two trailing spaces make a Markdown hard line break, so
                    # newlines in the output stay visible when rendered.
                    st.write(text.replace("\n", "  \n"))
                st.write(f"--- generated in {time_diff:.2f} seconds ---")
                model_names.append(generator.model_name)

            info = (
                "\n---\n"
                f"*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, "
                f"available: {memory.available / 10**9:.2f}GB*\n"
                f"*Text generated using seed {seed}*\n"
            )
            st.write(info)

            # Separate record, so `params` keeps only generate() keyword
            # arguments. "model" lists every generator that ran, and does not
            # depend on a loop variable that is unbound when none ran.
            run_record = {
                **params,
                "seed": seed,
                "prompt": st.session_state.text,
                "model": model_names,
            }
            params_text = json.dumps(run_record)
            print(params_text)
            st.json(params_text)
|
|
|
|
|
# Script entry point: build the page only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|