File size: 4,568 Bytes
46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 a19a543 3f553b1 a19a543 46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 3f553b1 a19a543 3f553b1 46ffa30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import time
import torch
import psutil
import streamlit as st
from generator import GeneratorFactory
# Index of the last CUDA device torch can see; evaluates to -1 when torch
# reports no CUDA devices at all.
# NOTE(review): nothing in the visible part of this file reads `device` --
# confirm whether another module depends on it before removing it.
device = torch.cuda.device_count() - 1

# Task identifier shared by every entry in GENERATOR_LIST. All listed models
# are described as English -> Dutch ("en->nl"), which matches this value.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"

# Backward-compatible alias. The historical name has the direction reversed
# relative to its value; it is kept so existing references keep working.
TRANSLATION_NL_TO_EN = TRANSLATION_EN_TO_NL

# Model configurations handed to GeneratorFactory. Each entry carries:
#   model_name      -- model identifier string
#   desc            -- short human-readable description
#   task            -- task identifier (see TRANSLATION_EN_TO_NL)
#   split_sentences -- boolean flag consumed by the generator implementation
#                      (presumably: translate sentence by sentence -- TODO confirm)
GENERATOR_LIST = [
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": False,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
        "desc": "longT5 large nl8 512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": False,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": True,
    },
]
def main():
    """Render the Babel Streamlit page and run each configured generator.

    The page consists of a sidebar (logo, title, generation parameters) and a
    main area with a text box and a "Run" button. When "Run" is pressed, the
    entered text is passed to every generator built from ``GENERATOR_LIST``;
    each result is shown together with its elapsed time and the parameters
    the generator reports having used, followed by a system-memory summary.

    The function reads ``style.css`` and ``babel.png`` from the current
    working directory and stores its state in ``st.session_state`` under the
    keys ``"generators"``, ``"prompt_box"`` and ``"text"``.
    """
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Babel",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
        # String, anything supported by st.image, or None.
        # NOTE(review): this literal looks like the first byte of a
        # mis-decoded emoji whose remaining bytes were lost, so the intended
        # glyph cannot be recovered here -- restore it from the original file.
        page_icon="๐",
    )

    # Build the generators once per session and reuse them on every rerun.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    # Inject the project stylesheet; decode explicitly as UTF-8 so the result
    # does not depend on the platform's default encoding.
    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel
Vertaal van en naar Engels"""
    )
    st.sidebar.title("Parameters:")

    # Seed the text box with a sample passage the first time the page loads.
    if "prompt_box" not in st.session_state:
        # Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
        st.session_state["prompt_box"] = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.
And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.
It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.
“My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""

    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=300
    )

    # Generation parameters, forwarded unchanged to each generator.
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    # NOTE(review): the link host below looks like a mirror of huggingface.co
    # -- confirm which domain the links are meant to point at.
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    params = {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
    }

    if st.button("Run"):
        # Snapshot taken before generation starts; it is reported after all
        # generators have finished.
        memory = psutil.virtual_memory()

        for generator in generators:
            st.markdown(f"🧮 **Model `{generator}`**")

            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments.
            time_start = time.perf_counter()
            result, params_used = generator.generate(
                text=st.session_state.text, **params
            )
            time_diff = time.perf_counter() - time_start

            # Markdown only treats a newline as a hard line break when it is
            # preceded by at least two spaces.
            st.write(result.replace("\n", "  \n"))
            text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
            # NOTE(review): the glyph after the leading space is a
            # mis-decoded emoji that cannot be recovered from this copy.
            st.markdown(f" ๐ *generated in {time_diff:.2f}s, `{text_line}`*")

        # `memory` is only bound inside this branch, so the report has to
        # live here as well.
        st.write(
            f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
        )


if __name__ == "__main__":
    main()
|