File size: 2,714 Bytes
463444e
f50408f
c3e5a3b
463444e
 
f50408f
93d168d
463444e
c3e5a3b
463444e
 
 
 
 
06d2814
93d168d
 
06d2814
 
 
 
463444e
 
 
06d2814
f50408f
06d2814
463444e
48ff56c
f50408f
 
 
48ff56c
93d168d
f50408f
 
 
93d168d
48ff56c
 
 
 
f50408f
 
48ff56c
9003587
48ff56c
 
 
f50408f
9003587
48ff56c
f50408f
 
 
48ff56c
9003587
48ff56c
 
463444e
93d168d
06d2814
f50408f
463444e
 
 
9003587
 
 
463444e
 
 
93d168d
9003587
f50408f
463444e
f50408f
463444e
 
9003587
 
 
463444e
9003587
463444e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes

# Load models and tokenizers once during initialization
def load_models():
    """Load every configured checkpoint together with its tokenizer.

    Returns:
        dict: maps each short model alias to a dict with two entries,
        ``"model"`` (the ``AutoModelForSeq2SeqLM`` instance) and
        ``"tokenizer"`` (the matching ``AutoTokenizer`` instance).
    """
    # Short alias -> checkpoint identifier passed to ``from_pretrained``.
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }

    loaded = {}
    for alias in checkpoints:
        checkpoint = checkpoints[alias]
        print("\tLoading model:", alias)
        # Model first, then tokenizer, both from the same checkpoint.
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        seq2seq_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        loaded[alias] = dict(model=seq2seq_model, tokenizer=seq2seq_tokenizer)

    return loaded

# Translate text using preloaded models and tokenizers
def translate_text(source, target, text, model_dict):
    """Translate ``text`` with the preloaded "nllb-distilled-600M" model.

    Args:
        source: Key into ``flores_codes`` naming the source language.
        target: Key into ``flores_codes`` naming the target language.
        text: Input handed to the translation pipeline.
        model_dict: Mapping of model alias to a dict holding ``"model"`` and
            ``"tokenizer"`` entries.

    Returns:
        dict with the keys ``"inference_time"`` (wall-clock seconds, which
        also covers the code lookups and pipeline construction),
        ``"source"`` and ``"target"`` (the looked-up language codes) and
        ``"result"`` (the ``translation_text`` of the first pipeline output).

    Raises:
        KeyError: if the model alias is missing from ``model_dict``, or if
            ``source`` / ``target`` is not a key of ``flores_codes``.
    """
    model_name = "nllb-distilled-600M"

    # Fail fast when the expected model was never loaded.
    if model_name not in model_dict:
        raise KeyError(f"Model '{model_name}' not found in model_dict")

    entry = model_dict[model_name]
    nllb_model, nllb_tokenizer = entry["model"], entry["tokenizer"]

    started = time.time()

    source_code = flores_codes[source]
    target_code = flores_codes[target]

    pipeline_kwargs = {
        "model": nllb_model,
        "tokenizer": nllb_tokenizer,
        "src_lang": source_code,
        "tgt_lang": target_code,
    }
    translator = pipeline("translation", **pipeline_kwargs)
    translations = translator(text, max_length=400)

    # Stop the clock before post-processing the pipeline output.
    elapsed = time.time() - started

    return {
        "inference_time": elapsed,
        "source": source_code,
        "target": target_code,
        "result": translations[0]["translation_text"],
    }

if __name__ == "__main__":
    # Load the model(s) once at startup so each request reuses them.
    print("\tInitializing models")
    model_dict = load_models()

    def translate(source, target, text):
        """Gradio callback: translate ``text`` using the preloaded models.

        The interface below has three input components, so Gradio calls the
        callback with three positional values. ``translate_text`` needs a
        fourth argument (``model_dict``); binding it here avoids the
        "missing required positional argument" TypeError on every request.
        """
        return translate_text(source, target, text, model_dict)

    # Human-readable language names offered in both dropdowns.
    lang_codes = list(flores_codes.keys())
    # NOTE(review): ``gr.inputs`` / ``gr.outputs`` are the legacy component
    # namespaces; confirm the pinned gradio version still provides them.
    inputs = [
        gr.inputs.Dropdown(lang_codes, default="English", label="Source"),
        gr.inputs.Dropdown(lang_codes, default="Nepali", label="Target"),
        gr.inputs.Textbox(lines=5, label="Input text"),
    ]
    outputs = gr.outputs.JSON()

    title = "The Master Betters Translator"
    description = (
        "This is a beta version of The Master Betters Translator that utilizes pre-trained language models for translation. To use this app you need to have chosen the source and target language with your input text to get the output."
    )
    # One example row: values in the same order as ``inputs``.
    examples = [["English", "Nepali", "Hello, how are you"]]

    gr.Interface(
        translate,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()