"""Streamlit demo app that translates text between English and Dutch."""

import time

import psutil
import streamlit as st
import torch
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from transformers import TextIteratorStreamer

from default_texts import default_texts
from generator import GeneratorFactory

# Index of the last visible CUDA device; evaluates to -1 when no CUDA device
# is present. NOTE(review): not referenced elsewhere in this file - presumably
# kept for other modules or interactive use; confirm before removing.
device = torch.cuda.device_count() - 1

TRANSLATION_EN_TO_NL = "translation_en_to_nl"
TRANSLATION_NL_TO_EN = "translation_nl_to_en"

# Model configurations handed to GeneratorFactory. Commented-out entries are
# models that are currently disabled.
GENERATOR_LIST = [
    {
        "model_name": "yhavinga/ul2-base-en-nl",
        "desc": "UL2 base en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": False,
    },
    # {
    #     "model_name": "yhavinga/ul2-large-en-nl",
    #     "desc": "UL2 large en->nl",
    #     "task": TRANSLATION_EN_TO_NL,
    #     "split_sentences": False,
    # },
    {
        "model_name": "Helsinki-NLP/opus-mt-en-nl",
        "desc": "Opus MT en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": True,
    },
    {
        "model_name": "Helsinki-NLP/opus-mt-nl-en",
        "desc": "Opus MT nl->en",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": True,
    },
    # {
    #     "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
    #     "desc": "T5 small nl24 ccmatrix nl-en",
    #     "task": TRANSLATION_NL_TO_EN,
    #     "split_sentences": True,
    # },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-neddx2-nl-en",
        "desc": "Long t5 large-nl8 nl-en",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": False,
    },
    # {
    #     "model_name": "yhavinga/byt5-small-ccmatrix-en-nl",
    #     "desc": "ByT5 small ccmatrix en->nl",
    #     "task": TRANSLATION_EN_TO_NL,
    #     "split_sentences": True,
    # },
    # {
    #     "model_name": "yhavinga/t5-base-36L-ccmatrix-multi",
    #     "desc": "T5 base nl36 ccmatrix en->nl",
    #     "task": TRANSLATION_EN_TO_NL,
    #     "split_sentences": True,
    # },
    # {
]


class StreamlitTextIteratorStreamer(TextIteratorStreamer):
    """Text streamer that mirrors finalized text into a Streamlit placeholder.

    Each chunk of finalized text is appended to ``output_text`` and the
    placeholder is re-rendered with the accumulated text, before the chunk is
    handed on to the base class.
    """

    def __init__(
        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
    ):
        """Create the streamer.

        Args:
            output_placeholder: Streamlit element whose ``markdown`` method is
                called with the accumulated text.
            tokenizer: Tokenizer forwarded to ``TextIteratorStreamer``.
            skip_prompt: Forwarded to ``TextIteratorStreamer``.
            **decode_kwargs: Forwarded to ``TextIteratorStreamer``.
        """
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.output_placeholder = output_placeholder
        self.output_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append ``text``, refresh the placeholder, then defer to the base class."""
        self.output_text += text
        self.output_placeholder.markdown(self.output_text, unsafe_allow_html=True)
        super().on_finalized_text(text, stream_end)


def main():
    """Render the translation page and run the matching models on request."""
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Rosetta en/nl",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
        page_icon="📑",  # String, anything supported by st.image, or None.
    )

    # Build the generator factory only once per session; later reruns reuse it.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    # Inject the project's stylesheet. The file contents must be wrapped in a
    # <style> element; an empty markdown string would apply no styling at all.
    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("rosetta.png", width=200)
    st.sidebar.markdown("""# Rosetta Vertaal van en naar Engels""")
    default_text = st.sidebar.radio(
        "Change default text",
        tuple(default_texts.keys()),
        index=0,
    )
    # NOTE(review): the radio selection is presumably always truthy, so the
    # prompt is reset to the selected default on every rerun - confirm that
    # this is intended.
    if default_text or "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = default_texts[default_text]["text"]

    # create a left and right column
    left, right = st.columns(2)
    text_area = left.text_area("Enter text", st.session_state.prompt_box, height=500)
    st.session_state["text"] = text_area

    # Sidebar parameters
    st.sidebar.title("Parameters:")
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
    )
    diversity_penalty = st.sidebar.number_input(
        "Diversity penalty", min_value=0.0, max_value=2.0, value=0.1, step=0.1
    )
    st.sidebar.markdown(
        "For an explanation of the parameters, head over to the "
        "[Huggingface blog post about text generation]"
        "(https://huggingface.co/blog/how-to-generate) and the "
        "[Huggingface text generation interface \n"
        "doc](https://huggingface.co/transformers/main_classes/model.html"
        "?highlight=generate#transformers.generation_utils.GenerationMixin.generate). "
    )

    # The penalties are neutralised when the setting they apply to is off.
    params = {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
        "length_penalty": length_penalty if num_beams > 1 else 1.0,
        "early_stopping": True,
    }

    if left.button("Run"):
        # Memory snapshot is taken before generation and reported afterwards.
        memory = psutil.virtual_memory()

        # langdetect raises when the text has no usable features (for example
        # empty input); report that instead of crashing the app.
        try:
            language = detect(st.session_state.text)
        except LangDetectException:
            left.error("Could not detect the language of the input text")
            return

        if language == "en":
            task = TRANSLATION_EN_TO_NL
        elif language == "nl":
            task = TRANSLATION_NL_TO_EN
        else:
            left.error(f"Language {language} not supported")
            return

        # Num beam groups should be a divisor of num beams
        if num_beams % num_beam_groups != 0:
            left.error("Num beams should be a multiple of num beam groups")
            return

        # Token streaming is only used without beam search.
        streaming_enabled = num_beams == 1
        if not streaming_enabled:
            left.markdown("*`num_beams > 1` so streaming is disabled*")

        for generator in generators.filter(task=task):
            model_container = right.container()
            model_container.markdown(f"🧮 **Model `{generator}`**")
            output_placeholder = model_container.empty()
            streamer = (
                StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
                if streaming_enabled
                else None
            )

            # perf_counter is the clock intended for measuring elapsed time.
            time_start = time.perf_counter()
            result, params_used = generator.generate(
                text=st.session_state.text, streamer=streamer, **params
            )
            time_diff = time.perf_counter() - time_start

            if not streaming_enabled:
                # Two trailing spaces before the newline form a Markdown hard
                # line break, so the line structure of the result is kept.
                right.write(result.replace("\n", "  \n"))

            text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
            right.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")

        # The rule has to sit on its own line to render as a horizontal rule.
        st.write(
            "\n---\n"
            f"*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, "
            f"available: {memory.available / 10**9:.2f}GB*\n"
        )


if __name__ == "__main__":
    main()