File size: 4,568 Bytes
46ffa30
3f553b1
46ffa30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f553b1
46ffa30
 
 
 
 
3f553b1
46ffa30
 
 
 
 
3f553b1
46ffa30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a19a543
3f553b1
a19a543
46ffa30
3f553b1
 
46ffa30
 
 
 
 
 
 
3f553b1
 
 
 
 
46ffa30
 
3f553b1
 
 
 
 
 
 
 
 
 
 
a19a543
3f553b1
 
 
 
 
 
 
 
 
46ffa30
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import time
import torch

import psutil
import streamlit as st

from generator import GeneratorFactory

device = torch.cuda.device_count() - 1

TRANSLATION_NL_TO_EN = "translation_en_to_nl"

GENERATOR_LIST = [
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": False,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
        "desc": "longT5 large nl8 512beta/512l en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": False,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": True,
    },
]


def main():
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Babel",  # String or None. Strings get appended with "โ€ข Streamlit".
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
        page_icon="๐Ÿ“š",  # String, anything supported by st.image, or None.
    )

    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel
    Vertaal van en naar Engels"""
    )
    st.sidebar.title("Parameters:")
    if "prompt_box" not in st.session_state:
        # Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
        st.session_state[
            "prompt_box"
        ] = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.

And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.

It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.

โ€œMy father is very ill,โ€ she said without a word of introduction. โ€œThe nurse is frightened. Could you come in and help?โ€"""
    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=300
    )
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )

    params = {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
    }

    if st.button("Run"):
        memory = psutil.virtual_memory()

        for generator in generators:
            st.markdown(f"๐Ÿงฎ **Model `{generator}`**")
            time_start = time.time()
            result, params_used = generator.generate(
                text=st.session_state.text, **params
            )
            time_end = time.time()
            time_diff = time_end - time_start

            st.write(result.replace("\n", "  \n"))
            text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
            st.markdown(f"    ๐Ÿ•™ *generated in {time_diff:.2f}s, `{text_line}`*")

        st.write(
            f"""
        ---
        *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
        """
        )


if __name__ == "__main__":
    main()