# Page residue from the Hugging Face Space file view: author "Yehor",
# latest commit 7ed3b9a "Update app.py" (verified).
import sys
import time
from importlib.metadata import version
import torch
import gradio as gr
from transformers import MBartForConditionalGeneration, AutoTokenizer
# Config
# Local model directory; NOTE(review): presumably a copy of the
# Pravopysnyk/best-unlp checkpoint mentioned in the page description.
model_name = "/home/user/app/best-unlp"
# Passed as concurrency_limit to the Correct button's click handler.
concurrency_limit = 5
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the weights once at import time; nn.Module.eval() returns the module
# itself, so the call can be chained onto from_pretrained.
model = MBartForConditionalGeneration.from_pretrained(
    model_name,
    device_map=device,
    low_cpu_mem_usage=True,
).eval()

# Source and target language are both Ukrainian: the model rewrites
# Ukrainian text into corrected Ukrainian text rather than translating.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.src_lang = "uk_UA"
tokenizer.tgt_lang = "uk_UA"
# Sample inputs containing mistakes, offered through gr.Examples below the
# Correct button so users can try the app without typing.
examples = ["привіт як справі?", "як твої дела?"]
title = "Grammar Correction for Ukrainian"
# Markdown "Authors" section; it becomes description_foot and is rendered
# near the bottom of the page. The table layout was produced with:
# https://www.tablesgenerator.com/markdown_tables
# NOTE(review): "[email protected]" looks like e-mail-obfuscation residue from
# a web copy of this file; the real address should be restored.
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()
# Markdown header shown at the top of the page (H1 title + overview).
# NOTE(review): model_name is a local path; confirm it matches the Hugging
# Face model linked in this text.
description_head = f"""
# {title}
## Overview
This space uses https://huggingface.co/Pravopysnyk/best-unlp model.
Paste the text you want to enhance.
""".strip()
# Markdown footer; currently it is only the authors table.
description_foot = f"""
{authors_table}
""".strip()
# Placeholder shown in the output textbox before any correction has run.
normalized_text_value = """
Corrected text will appear here.
Choose **an example** below the Correct button or paste **your text**.
""".strip()
# Markdown snippet reporting the interpreter this app runs under.
tech_env = "\n".join(
    [
        "#### Environment",
        f"- Python: {sys.version}",
    ]
).strip()
# Markdown snippet listing the installed versions of the key libraries,
# read from package metadata at import time.
tech_libraries = "\n".join(
    ["#### Libraries"]
    + [
        f"- {package}: {version(package)}"
        for package in ("torch", "gradio", "transformers")
    ]
).strip()
def _correct_sentence(sentence):
    """Return the model's correction of one non-empty, stripped sentence.

    Tokenizes with truncation at 1024 tokens, runs 5-beam generation capped
    at max_length=1024, and decodes without special tokens. Returns "-" when
    the decoded output is empty or whitespace-only, so the UI never shows a
    blank quote line.
    """
    encoded_input = tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    ).to(device)
    output_ids = model.generate(
        **encoded_input, max_length=1024, num_beams=5, early_stopping=True
    )
    # Strip before the emptiness check so whitespace-only output also maps
    # to the "-" placeholder instead of an empty string.
    corrected = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return corrected or "-"


def _format_results(results):
    """Render per-sentence results as Markdown quotes plus a total-time line."""
    lines = []
    for result in results:
        lines.append(f'> {result["normalized_text"]}')
        lines.append("\n")
    # Round the total: summing already-rounded floats can still print
    # artifacts such as 0.30000000000000004.
    total_elapsed = round(sum(result["elapsed_time"] for result in results), 2)
    lines.append(f"Elapsed time: {total_elapsed} seconds")
    return "\n".join(lines)


def inference(text, progress=gr.Progress()):
    """Correct the grammar of Ukrainian ``text`` with the loaded MBart model.

    ``progress`` uses a ``gr.Progress()`` default on purpose: that is how
    Gradio recognizes the parameter and injects a progress tracker, so it is
    not the usual mutable-default pitfall.

    Args:
        text: Raw text from the input textbox.
        progress: Gradio progress tracker.

    Returns:
        Markdown text: each corrected sentence as a ``>`` quote line,
        followed by the total processing time in seconds.

    Raises:
        gr.Error: If ``text`` is empty, None, or whitespace-only.
    """
    # Also reject whitespace-only input: otherwise every sentence is skipped
    # below and the user only gets "Elapsed time: 0 seconds".
    if not text or not text.strip():
        raise gr.Error("Please paste your text.")

    gr.Info("Starting", duration=2)
    progress(0, desc="Correcting...")

    # The whole input is corrected as one unit; the list keeps the loop and
    # progress bar ready for per-sentence splitting.
    sentences = [text]

    results = []
    for sentence in progress.tqdm(sentences, desc="Correcting...", unit="sentence"):
        sentence = sentence.strip()
        if not sentence:
            continue
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        normalized_text = _correct_sentence(sentence)
        elapsed_time = round(time.perf_counter() - t0, 2)
        results.append(
            {
                "sentence": sentence,
                "normalized_text": normalized_text,
                "elapsed_time": elapsed_time,
            }
        )

    gr.Info("Finished!", duration=2)
    return _format_results(results)
# Build the page layout. Using the Blocks constructor directly in the `with`
# statement is equivalent to creating it first and then entering it.
with gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
) as demo:
    gr.Markdown(description_head)
    gr.Markdown("## Usage")

    # Input and output textboxes side by side.
    with gr.Row():
        text = gr.Textbox(label="Text", autofocus=True, max_lines=1)
        normalized_text = gr.Textbox(
            label="Corrected text",
            placeholder=normalized_text_value,
            show_copy_button=True,
        )

    correct_button = gr.Button("Correct")
    correct_button.click(
        inference,
        inputs=text,
        outputs=normalized_text,
        concurrency_limit=concurrency_limit,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=text, examples=examples)

    # Footer: authors, then the runtime/library versions.
    gr.Markdown(description_foot)
    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)
if __name__ == "__main__":
    # Blocks.queue() returns the same Blocks instance, so enabling the
    # request queue and starting the server can be chained.
    demo.queue().launch()