"""Gradio Space: grammar correction for Ukrainian text with an mBART model."""

import sys
import time

from importlib.metadata import version

import torch
import gradio as gr

from transformers import MBartForConditionalGeneration, AutoTokenizer

# Config
model_name = "/home/user/app/best-unlp"
concurrency_limit = 5

# Generation settings; max_length bounds both the tokenized input
# (truncation) and the generated output.
max_length = 1024
num_beams = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model
model = MBartForConditionalGeneration.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    device_map=device,
)
model.eval()

# Correction is Ukrainian -> Ukrainian, so source and target language codes match.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.src_lang = "uk_UA"
tokenizer.tgt_lang = "uk_UA"

examples = [
    "привіт як справі?",
    "як твої дела?",
]

title = "Grammar Correction for Ukrainian"

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact** if you need any help or have any questions:

| **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()

description_head = f"""
# {title}

## Overview

This space uses https://huggingface.co/Pravopysnyk/best-unlp model.

Paste the text you want to enhance.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

normalized_text_value = """
Corrected text will appear here.

Choose **an example** below the Correct button or paste **your text**.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- gradio: {version('gradio')}
- transformers: {version('transformers')}
""".strip()


def _correct_sentence(sentence):
    """Run beam-search generation on one non-empty string and return the decoded text.

    Input longer than ``max_length`` tokens is truncated by the tokenizer.
    """
    encoded_input = tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)

    output_ids = model.generate(
        **encoded_input,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def inference(text, progress=gr.Progress()):
    """Correct ``text`` with the model and format the result for the output box.

    Args:
        text: Raw text from the input textbox.
        progress: Progress tracker; Gradio injects it when it sees a
            ``gr.Progress()`` default, so it is intentionally a default argument.

    Returns:
        Markdown-ish text: each corrected item prefixed with ``> ``, followed by
        the total elapsed time in seconds.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Reject whitespace-only input too; otherwise the loop below skips it and
    # the user only sees "Elapsed time: 0 seconds".
    if not text or not text.strip():
        raise gr.Error("Please paste your text.")

    gr.Info("Starting", duration=2)

    progress(0, desc="Correcting...")

    results = []

    # The whole input is corrected as a single unit; the list form keeps the
    # per-item loop and progress bar ready for future sentence splitting.
    sentences = [text]

    for sentence in progress.tqdm(sentences, desc="Correcting...", unit="sentence"):
        sentence = sentence.strip()
        if not sentence:
            continue

        t0 = time.perf_counter()

        # Strip before the emptiness check so that whitespace-only model
        # output also falls back to the "-" placeholder.
        normalized_text = _correct_sentence(sentence).strip() or "-"

        elapsed_time = round(time.perf_counter() - t0, 2)

        results.append(
            {
                "sentence": sentence,
                "normalized_text": normalized_text,
                "elapsed_time": elapsed_time,
            }
        )

    gr.Info("Finished!", duration=2)

    result_texts = []

    for result in results:
        result_texts.append(f'> {result["normalized_text"]}')
        result_texts.append("\n")

    # Round the total as well: summing already-rounded floats can reintroduce
    # representation noise (e.g. 0.30000000000000004).
    sum_elapsed_text = round(sum(result["elapsed_time"] for result in results), 2)

    result_texts.append(f"Elapsed time: {sum_elapsed_text} seconds")

    return "\n".join(result_texts)


demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        text = gr.Textbox(label="Text", autofocus=True, max_lines=1)
        normalized_text = gr.Textbox(
            label="Corrected text",
            placeholder=normalized_text_value,
            show_copy_button=True,
        )

    gr.Button("Correct").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=text,
        outputs=normalized_text,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=text, examples=examples)

    gr.Markdown(description_foot)

    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()