import sys
import time

from importlib.metadata import version

import torch
import gradio as gr
from transformers import MBartForConditionalGeneration, AutoTokenizer

# Config
model_name = "/home/user/app/best-unlp"
concurrency_limit = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model
model = MBartForConditionalGeneration.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    device_map=device,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.src_lang = "uk_UA"
tokenizer.tgt_lang = "uk_UA"

examples = [
    "привіт як справі?",
    "як твої дела?",
]

title = "Grammar Correction for Ukrainian"

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact them** if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|--------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()

description_head = f"""
# {title}

## Overview

This Space uses the https://huggingface.co/Pravopysnyk/best-unlp model.

Paste the text you want to enhance.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

normalized_text_value = """
Corrected text will appear here.

Choose **an example** below the Correct button or paste **your text**.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- gradio: {version('gradio')}
- transformers: {version('transformers')}
""".strip()


def inference(text, progress=gr.Progress()):
    if not text:
        raise gr.Error("Please paste your text.")

    gr.Info("Starting", duration=2)
    progress(0, desc="Correcting...")

    results = []

    # The pasted text is handled as a single "sentence"; each item is corrected independently.
    sentences = [
        text,
    ]

    for sentence in progress.tqdm(sentences, desc="Correcting...", unit="sentence"):
        sentence = sentence.strip()
        if len(sentence) == 0:
            continue

        t0 = time.time()

        input_text = sentence

        encoded_input = tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(device)

        # Beam-search decoding of the corrected sentence
        output_ids = model.generate(
            **encoded_input, max_length=1024, num_beams=5, early_stopping=True
        )

        normalized_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        if not normalized_text:
            normalized_text = "-"

        elapsed_time = round(time.time() - t0, 2)

        normalized_text = normalized_text.strip()

        results.append(
            {
                "sentence": sentence,
                "normalized_text": normalized_text,
                "elapsed_time": elapsed_time,
            }
        )

    gr.Info("Finished!", duration=2)

    result_texts = []
    for result in results:
        result_texts.append(f'> {result["normalized_text"]}')
        result_texts.append("\n")

    sum_elapsed_time = sum([result["elapsed_time"] for result in results])
    result_texts.append(f"Elapsed time: {sum_elapsed_time} seconds")

    return "\n".join(result_texts)
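

# A minimal smoke test (an assumption, not part of the original Space): inference()
# can be called directly to check the model before launching the UI. Outside a live
# Gradio request, gr.Info() may only log a warning instead of showing a toast.
#
#     print(inference("привіт як справі?"))
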
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    # theme="huggingface",
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        text = gr.Textbox(label="Text", autofocus=True, max_lines=1)
        normalized_text = gr.Textbox(
            label="Corrected text",
            placeholder=normalized_text_value,
            show_copy_button=True,
        )

    gr.Button("Correct").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=text,
        outputs=normalized_text,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=text, examples=examples)

    gr.Markdown(description_foot)

    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()
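
# A hedged alternative for running outside Hugging Face Spaces (assumed host/port,
# adjust as needed): binding to 0.0.0.0 makes the app reachable from other machines.
#
#     demo.launch(server_name="0.0.0.0", server_port=7860)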