|
import streamlit as st
from functools import lru_cache

from .services import TextGeneration
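
# The TextGeneration service is defined in .services. A minimal sketch of the
# interface this page relies on, inferred from the calls below (not the actual
# implementation):
#
#     class TextGeneration:
#         def load(self) -> None:
#             ...  # load the model weights and tokenizer
#
#         def generate(self, prompt, model_name, max_new_tokens, temperature,
#                      top_k, top_p, repetition_penalty, do_sample, num_beams,
#                      no_repeat_ngram_size) -> str:
#             ...  # return the generated text as a string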


@lru_cache(maxsize=1)
def load_text_generator():
    # Cache a single loaded instance so the expensive model load happens once
    # per process instead of on every Streamlit rerun.
    generator = TextGeneration()
    generator.load()
    return generator


# Instantiate at import time so the model is ready before the page renders.
generator = load_text_generator()

# Zero-shot QA prompt pieces (Arabic).
qa_prompt = """
أجب عن السؤال التالي:
"""  # "Answer the following question:"
qa_prompt_post = """ الجواب هو """  # "The answer is"
qa_prompt_post_year = """ في سنة: """  # "in the year:"
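
# With the date checkbox ticked, the "Answer" handler below assembles a prompt
# of the form:
#
#     أجب عن السؤال التالي: <question> الجواب هو  في سنة:
#
# i.e. "Answer the following question: <question> The answer is in the year:",
# leaving the model to complete the answer.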


def write():
    st.markdown(
        """
        <h1 style="text-align:left;">Arabic Language Generation</h1>
        """,
        unsafe_allow_html=True,
    )

    st.sidebar.subheader("Configurable parameters")

    model_name = st.sidebar.selectbox(
        "Model Selector",
        options=[
            "AraGPT2-Base",
            "AraGPT2-Mega",
        ],
        index=0,
    )
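
    # These names are presumably resolved inside TextGeneration to the
    # aubmindlab/aragpt2-base and aubmindlab/aragpt2-mega checkpoints on the
    # Hugging Face Hub; the actual mapping lives in .services.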
|
|
|
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=0,
        max_value=1024,
        value=100,
        help="The maximum number of new tokens to generate.",
    )
    temp = st.sidebar.slider(
        "Temperature",
        value=1.0,
        min_value=0.1,
        max_value=100.0,
        help="The value used to modulate the next-token probabilities.",
    )
    top_k = st.sidebar.number_input(
        "Top k",
        value=10,
        help="The number of highest-probability vocabulary tokens to keep for top-k filtering.",
    )
    top_p = st.sidebar.number_input(
        "Top p",
        value=0.95,
        help="If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
    )
    do_sample = st.sidebar.selectbox(
        "Sampling?",
        (True, False),
        help="Whether or not to use sampling; uses greedy decoding otherwise.",
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        min_value=1,
        max_value=10,
        value=3,
        help="The number of beams to use for beam search.",
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition Penalty",
        min_value=0.0,
        value=3.0,
        step=0.1,
        help="The parameter for repetition penalty. 1.0 means no penalty.",
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No Repeat N-Gram Size",
        min_value=0,
        value=3,
        help="If set to an int > 0, all n-grams of that size can only occur once.",
    )
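
    # These controls mirror the decoding arguments of Hugging Face transformers'
    # generate() API (max_new_tokens, temperature, top_k, top_p, do_sample,
    # num_beams, repetition_penalty, no_repeat_ngram_size); the assumption is
    # that TextGeneration.generate() forwards them to the model unchanged.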
|
|
|
    st.write("#")
    col = st.columns(2)
    col[0].image("images/AraGPT2.png", width=200)

    st.markdown(
        """
        <h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
        <h4 style="text-align:left;">More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>

        <p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
        """,
        unsafe_allow_html=True,
    )

    # Right-align text in widgets to suit Arabic (RTL) content.
    st.markdown(
        """
        <style>
        p, div, input, label, textarea {
            text-align: right;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    # Default prompt: "It is said that a deceitful farmer sold the water well on
    # his land to his neighbor for a large sum of money."
    prompt = st.text_area(
        "Prompt",
        "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
    )
    if st.button("Generate"):
        with st.spinner("Generating..."):
            generated_text = generator.generate(
                prompt=prompt,
                model_name=model_name,
                max_new_tokens=max_new_tokens,
                temperature=temp,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
            st.write(generated_text)

    st.markdown("---")
    st.markdown(
        """
        <h2 style="text-align:left;">Zero-Shot Question Answering</h2>

        <p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended.</p>
        """,
        unsafe_allow_html=True,
    )
    # Default question: "Who was the leader of Nazi Germany in World War II?"
    question = st.text_input(
        "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
    )
    is_date = st.checkbox("Help the model: Is the answer a date?")
    if st.button("Answer"):
        # Wrap the question in the zero-shot QA template, appending a year hint
        # when the user says the expected answer is a date.
        prompt2 = qa_prompt + question + qa_prompt_post
        if is_date:
            prompt2 += qa_prompt_post_year
        else:
            prompt2 += " : "
        with st.spinner("Thinking..."):
            answer = generator.generate(
                prompt=prompt2,
                model_name=model_name,
                max_new_tokens=max_new_tokens,
                temperature=temp,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
            st.write(answer)
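

# This module is meant to be imported by the app's entrypoint (note the relative
# import of .services) and rendered by calling write(). A hypothetical usage,
# assuming the package is named "app" and this file is app/generate.py:
#
#     from app import generate
#     generate.write()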
|
|