# rosetta / app.py
# (Web file-viewer residue kept for provenance: "yhavinga's picture",
#  commit "Add app", 46ffa30, "raw", "history blame", 8.25 kB)
import json
import os
import time
from random import randint
import psutil
import streamlit as st
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
pipeline,
set_seed,
)
from generator import GeneratorFactory
# Number of visible CUDA devices minus one. This evaluates to -1 when torch
# reports no CUDA device, which estimate_time() in main() treats as the CPU case.
device = torch.cuda.device_count() - 1
# Task identifier shared by every model below. All of them are described as
# English -> Dutch ("en->nl"), which is what the value says.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"

# Backward-compatible alias: the constant was originally published under this
# (misleading) name while always holding the en->nl task string.
TRANSLATION_NL_TO_EN = TRANSLATION_EN_TO_NL

# Models offered by the app. Each entry carries:
#   model_name - model identifier (looks like a Hugging Face hub id)
#   desc       - short human-readable label
#   task       - task string handed to GeneratorFactory
GENERATOR_LIST = [
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
        "desc": "longT5 large nl8 512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
]
# Default text shown in the prompt box the first time a session is opened.
# Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
_DEFAULT_PROMPT = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.
And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.
It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.
“My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""


def main():
    """Render the "Babel" Streamlit page and run the generators on demand.

    The page consists of:
      * a sidebar with the logo, a model selector and the generation
        parameters (length, repetition controls, seed, and either top-k
        sampling or beam-search settings);
      * a text area holding the text to process;
      * a "Set new random seed" button and a "Run" button.

    Pressing "Run" calls ``generate`` on every generator held in
    ``st.session_state["generators"]`` with the collected parameters, writes
    each result to the page, and finally shows memory statistics and a JSON
    dump of the parameters that were used.
    """
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Babel",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
        page_icon="📚",  # String, anything supported by st.image, or None.
    )

    # Build the generators once per session and reuse them on every rerun.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    # Inject the custom stylesheet; read it as UTF-8 regardless of the locale.
    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown("# Babel\nVertaal van en naar Engels")
    # NOTE(review): the selected description is not read anywhere below;
    # "Run" iterates over every generator. The widget is kept so the sidebar
    # layout stays unchanged.
    st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
    st.sidebar.title("Parameters:")

    if "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = _DEFAULT_PROMPT
    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=300
    )

    max_length = st.sidebar.number_input(
        "Lengte van de tekst",
        min_value=1,  # a zero or negative length is never meaningful
        value=200,
        max_value=4096,
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No-repeat NGram size", min_value=1, max_value=5, value=3
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
    )
    num_return_sequences = st.sidebar.number_input(
        "Num return sequences", min_value=1, max_value=5, value=1
    )

    # The seed input lives in a placeholder so that set_random_seed() can
    # replace it with a widget showing the freshly drawn value.
    seed_placeholder = st.sidebar.empty()
    if "seed" not in st.session_state:
        print("Session state does not contain seed")
        st.session_state["seed"] = 4162549114
        print(f"Seed is set to: {st.session_state['seed']}")
    seed = seed_placeholder.number_input(
        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
    )

    def set_random_seed():
        """Draw a new seed, store it in the session and redraw the seed input."""
        # Rebind the enclosing `seed` instead of creating a shadowing local,
        # so the rest of this run sees the value shown in the widget.
        nonlocal seed
        st.session_state["seed"] = randint(0, 2**32 - 1)
        seed = seed_placeholder.number_input(
            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
        )
        print(f"New random seed set to: {seed}")

    if st.button("Set new random seed"):
        set_random_seed()

    sampling_mode = st.sidebar.selectbox(
        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
    )

    # Parameters common to both decoding modes. Built unconditionally so that
    # `params` is always bound before it is used further down.
    params = {
        "max_length": max_length,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": num_return_sequences,
    }
    if sampling_mode == "Beam Search":
        num_beams = st.sidebar.number_input(
            "Num beams", min_value=1, max_value=10, value=4
        )
        length_penalty = st.sidebar.number_input(
            "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
        )
        params.update(
            num_beams=num_beams,
            early_stopping=True,
            length_penalty=length_penalty,
        )
    else:
        top_k = st.sidebar.number_input(
            "Top K", min_value=0, max_value=100, value=50
        )
        top_p = st.sidebar.number_input(
            "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
        )
        temperature = st.sidebar.number_input(
            "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
        )
        params.update(
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

    st.sidebar.markdown(
        "For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)\n"
        "and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).\n"
    )

    def estimate_time():
        """Estimate the time it takes to generate the text.

        Returns a whole number of seconds computed from the requested length,
        the number of return sequences and (for beam search) the number of
        beams, using different scaling factors for CPU and GPU.
        """
        estimate = max_length / 18
        if device == -1:
            # CPU
            estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
            if sampling_mode == "Beam Search":
                estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
        else:
            # GPU
            estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
            estimate = 0.5 + estimate / 5
            if sampling_mode == "Beam Search":
                estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
        return int(estimate)

    if st.button("Run"):
        estimate = estimate_time()
        with st.spinner(
            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
        ):
            memory = psutil.virtual_memory()
            # Remembered for the run report below; stays None when there is
            # no generator at all instead of raising a NameError.
            last_model_name = None
            for generator in generators:
                st.subheader(f"Result from {generator}")
                # Re-seed before every model so each one starts from the
                # same random state.
                set_seed(seed)
                time_start = time.time()
                result = generator.generate(text=st.session_state.text, **params)
                time_diff = time.time() - time_start
                for text in result:
                    # Markdown needs two trailing spaces for a hard line break.
                    st.write(text.replace("\n", "  \n"))
                st.write(f"--- generated in {time_diff:.2f} seconds ---")
                last_model_name = generator.model_name

            info = (
                "\n---\n"
                f"*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, "
                f"available: {memory.available / 10**9:.2f}GB*  \n"
                f"*Text generated using seed {seed}*\n"
            )
            st.write(info)

            # Report on a copy: `params` is what gets splatted into
            # generate(), so bookkeeping keys must never be written into it.
            run_report = {
                **params,
                "seed": seed,
                "prompt": st.session_state.text,
                "model": last_model_name,
            }
            params_text = json.dumps(run_report)
            print(params_text)
            st.json(params_text)
# Script entry point: build the page only when this file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()