import json
import time
from random import randint

import psutil
import streamlit as st
import torch
from transformers import set_seed

from generator import GeneratorFactory
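
# Index of the last GPU, or -1 when none is available; -1 selects the CPU in
# the transformers pipeline device convention.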
device = torch.cuda.device_count() - 1

TRANSLATION_EN_TO_NL = "translation_en_to_nl"
GENERATOR_LIST = [
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
        "desc": "longT5 large nl8 512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
]


def main():
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Babel",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
        page_icon="🔤",  # String, anything supported by st.image, or None.
    )
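
    # Load the generator models once and cache them in session state so
    # Streamlit reruns do not reload them.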
if "generators" not in st.session_state:
st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
generators = st.session_state["generators"]
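
    # Inject the app's custom stylesheet.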
with open("style.css") as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel
Translate from and to English"""
    )
    model_desc = st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
    st.sidebar.title("Parameters:")
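
    # Pre-fill the prompt box with a sample passage on first load.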
if "prompt_box" not in st.session_state:
# Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
st.session_state[
"prompt_box"
] = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.
And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.
It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.
βMy father is very ill,β she said without a word of introduction. βThe nurse is frightened. Could you come in and help?β"""
st.session_state["text"] = st.text_area(
"Enter text", st.session_state.prompt_box, height=300
)

    max_length = st.sidebar.number_input(
        "Length of the text",
        value=200,
        max_value=4096,
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No-repeat NGram size", min_value=1, max_value=5, value=3
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
    )
    num_return_sequences = st.sidebar.number_input(
        "Num return sequences", min_value=1, max_value=5, value=1
    )
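
    # Keep the generation seed in session state so results are reproducible
    # across reruns.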
    seed_placeholder = st.sidebar.empty()
    if "seed" not in st.session_state:
        print("Session state does not contain seed")
        st.session_state["seed"] = 4162549114
    print(f"Seed is set to: {st.session_state['seed']}")
    seed = seed_placeholder.number_input(
        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
    )

    def set_random_seed():
        st.session_state["seed"] = randint(0, 2**32 - 1)
        seed = seed_placeholder.number_input(
            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
        )
        print(f"New random seed set to: {seed}")

    if st.button("Set new random seed"):
        set_random_seed()
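
    # Collect the generate() keyword arguments for the selected decoding mode.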
    sampling_mode = st.sidebar.selectbox(
        "Select a mode", index=0, options=["Top-k Sampling", "Beam Search"]
    )
    if sampling_mode == "Beam Search":
        num_beams = st.sidebar.number_input(
            "Num beams", min_value=1, max_value=10, value=4
        )
        length_penalty = st.sidebar.number_input(
            "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
        )
        params = {
            "max_length": max_length,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "repetition_penalty": repetition_penalty,
            "num_return_sequences": num_return_sequences,
            "num_beams": num_beams,
            "early_stopping": True,
            "length_penalty": length_penalty,
        }
    else:
        top_k = st.sidebar.number_input(
            "Top K", min_value=0, max_value=100, value=50
        )
        top_p = st.sidebar.number_input(
            "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
        )
        temperature = st.sidebar.number_input(
            "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
        )
        params = {
            "max_length": max_length,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "repetition_penalty": repetition_penalty,
            "num_return_sequences": num_return_sequences,
            "do_sample": True,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
        }

    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
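
    # Rough, hand-tuned latency heuristic; the result only feeds the spinner's
    # wait-time message.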
    def estimate_time():
        """Estimate the time it takes to generate the text."""
        estimate = max_length / 18
        if device == -1:
            # CPU
            estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
            if sampling_mode == "Beam Search":
                estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
        else:
            # GPU
            estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
            estimate = 0.5 + estimate / 5
            if sampling_mode == "Beam Search":
                estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
        return int(estimate)
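
    # Run every loaded generator on the input text and report the results,
    # timing, and memory usage.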
if st.button("Run"):
estimate = estimate_time()
with st.spinner(
text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
):
memory = psutil.virtual_memory()
for generator in generators:
st.subheader(f"Result from {generator}")
set_seed(seed)
time_start = time.time()
result = generator.generate(text=st.session_state.text, **params)
time_end = time.time()
time_diff = time_end - time_start
for text in result:
st.write(text.replace("\n", " \n"))
st.write(f"--- generated in {time_diff:.2f} seconds ---")
info = f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
*Text generated using seed {seed}*
"""
st.write(info)
params["seed"] = seed
params["prompt"] = st.session_state.text
params["model"] = generator.model_name
params_text = json.dumps(params)
print(params_text)
st.json(params_text)


if __name__ == "__main__":
    main()