|
import time |
|
|
|
import psutil |
|
import streamlit as st |
|
import torch |
|
from langdetect import detect |
|
from transformers import TextIteratorStreamer |
|
|
|
from default_texts import default_texts |
|
from generator import GeneratorFactory |
|
|
|
# Index of the highest-numbered visible CUDA device, or -1 when torch
# reports no CUDA devices at all.
device = torch.cuda.device_count() - 1

# Task identifiers shared by the model registry below and by the language
# routing in main().
TRANSLATION_EN_TO_NL = "translation_en_to_nl"
TRANSLATION_NL_TO_EN = "translation_nl_to_en"

# Registry of translation models handed to GeneratorFactory.
# Every entry has the same four keys, in this order:
#   model_name      - model repository id (Hugging Face Hub style)
#   desc            - short human-readable label (presumably for display;
#                     confirm in generator.py)
#   task            - one of the TRANSLATION_* constants above
#   split_sentences - flag consumed by the generator; presumably "translate
#                     sentence by sentence" (verify in generator.py)
GENERATOR_LIST = [
    dict(
        model_name="yhavinga/ul2-base-en-nl",
        desc="UL2 base en->nl",
        task=TRANSLATION_EN_TO_NL,
        split_sentences=False,
    ),
    dict(
        model_name="Helsinki-NLP/opus-mt-en-nl",
        desc="Opus MT en->nl",
        task=TRANSLATION_EN_TO_NL,
        split_sentences=True,
    ),
    dict(
        model_name="Helsinki-NLP/opus-mt-nl-en",
        desc="Opus MT nl->en",
        task=TRANSLATION_NL_TO_EN,
        split_sentences=True,
    ),
    dict(
        model_name="yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-neddx2-nl-en",
        desc="Long t5 large-nl8 nl-en",
        task=TRANSLATION_NL_TO_EN,
        split_sentences=False,
    ),
]
|
|
|
|
|
class StreamlitTextIteratorStreamer(TextIteratorStreamer):
    """Text streamer that mirrors generated text into a Streamlit placeholder.

    Each finalized chunk is appended to ``output_text`` and the placeholder is
    re-rendered with the whole text produced so far. The chunk is then passed
    on to ``TextIteratorStreamer.on_finalized_text`` so the base-class
    behaviour is kept.

    Args:
        output_placeholder: Streamlit element (e.g. from ``st.empty()``) whose
            content is replaced on every update.
        tokenizer: tokenizer forwarded to ``TextIteratorStreamer``.
        skip_prompt: forwarded to ``TextIteratorStreamer``.
        **decode_kwargs: forwarded to ``TextIteratorStreamer``.
    """

    def __init__(
        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.output_placeholder = output_placeholder
        self.output_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output_text += text
        # Render as plain markdown, never raw HTML: the generated text echoes
        # user input, so tags in it must not be interpreted. Markdown needs
        # two trailing spaces before "\n" for a hard line break; this matches
        # how main() renders non-streamed results.
        self.output_placeholder.markdown(self.output_text.replace("\n", "  \n"))
        super().on_finalized_text(text, stream_end)
|
|
|
|
|
def _sidebar_generation_params():
    """Render the generation-parameter widgets in the sidebar; return generate() kwargs."""
    st.sidebar.title("Parameters:")
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    diversity_penalty = st.sidebar.number_input(
        "Diversity penalty", min_value=0.0, max_value=2.0, value=0.1, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    return {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        # A diversity penalty is only sent when beam groups are in use.
        "diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
        "length_penalty": length_penalty,
        "early_stopping": True,
    }


def _detect_task(text):
    """Return the translation task for *text*.

    Raises:
        ValueError: with a user-facing message when the text is empty, its
            language cannot be detected, or the language is not en/nl.
    """
    from langdetect import DetectorFactory
    from langdetect.lang_detect_exception import LangDetectException

    # langdetect raises on input without any features (e.g. empty or
    # whitespace-only text), so reject it up front with a readable message.
    if not text or not text.strip():
        raise ValueError("Please enter some text to translate")

    # langdetect is non-deterministic unless seeded; seed it so the same text
    # is always routed to the same set of models.
    DetectorFactory.seed = 0
    try:
        language = detect(text)
    except LangDetectException:
        raise ValueError("Could not detect the language of the text") from None

    if language == "en":
        return TRANSLATION_EN_TO_NL
    if language == "nl":
        return TRANSLATION_NL_TO_EN
    raise ValueError(f"Language {language} not supported")


def _run_generator(generator, text, params, column, streaming_enabled):
    """Run one generator on *text*; render its output and timing into *column*."""
    model_container = column.container()
    model_container.markdown(f"🧮 **Model `{generator}`**")
    output_placeholder = model_container.empty()
    streamer = (
        StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
        if streaming_enabled
        else None
    )

    # perf_counter is monotonic, which makes it the right clock for durations.
    time_start = time.perf_counter()
    result, params_used = generator.generate(text=text, streamer=streamer, **params)
    time_diff = time.perf_counter() - time_start

    if not streaming_enabled:
        # Two trailing spaces before "\n" make markdown keep the line break.
        output_placeholder.write(result.replace("\n", "  \n"))

    text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
    # NOTE(review): the icon here was mis-encoded in the reviewed copy; confirm.
    model_container.markdown(f"⏱️ *generated in {time_diff:.2f}s, `{text_line}`*")


def main():
    """Render the Rosetta en<->nl translation app and run the models on request."""
    st.set_page_config(
        page_title="Rosetta en/nl",
        layout="wide",
        initial_sidebar_state="auto",
        # NOTE(review): icon was mis-encoded in the reviewed copy; confirm.
        page_icon="🌐",
    )

    # Build the generators once per session; later reruns reuse them.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("rosetta.png", width=200)
    st.sidebar.markdown(
        """# Rosetta

Vertaal van en naar Engels"""
    )

    default_text = st.sidebar.radio(
        "Change default text",
        tuple(default_texts.keys()),
        index=0,
    )
    # The radio always returns one of the sample keys, so the sample text is
    # refreshed on every rerun; the text area uses it only as initial value.
    st.session_state["prompt_box"] = default_texts[default_text]["text"]

    left, right = st.columns(2)
    text = left.text_area("Enter text", st.session_state.prompt_box, height=500)
    st.session_state["text"] = text

    params = _sidebar_generation_params()

    if not left.button("Run"):
        return

    try:
        task = _detect_task(text)
    except ValueError as err:
        left.error(str(err))
        return

    if params["num_beams"] % params["num_beam_groups"] != 0:
        left.error("Num beams should be a multiple of num beam groups")
        return

    # Streaming into the page is only used with a single beam.
    streaming_enabled = params["num_beams"] == 1
    if not streaming_enabled:
        left.markdown("*`num_beams > 1` so streaming is disabled*")

    for generator in generators.filter(task=task):
        _run_generator(generator, text, params, right, streaming_enabled)

    # Sampled after the models ran, so the figures reflect post-generation state.
    memory = psutil.virtual_memory()
    st.write(
        f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
    )


if __name__ == "__main__":
    main()
|
|