Elalimy commited on
Commit
025659e
·
verified ·
1 Parent(s): f23dcf4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from peft import PeftModel, PeftConfig
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+
5
+ HUGGING_FACE_USER_NAME = "elalimy"
6
+ model_name = "my_awesome_peft_finetuned_helsinki_model"
7
+ peft_model_id = f"{HUGGING_FACE_USER_NAME}/{model_name}"
8
+
9
+ # Load model configuration (assuming it's saved locally)
10
+ config = PeftConfig.from_pretrained(peft_model_id)
11
+ # Load the base model from its local directory (replace with actual model type)
12
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=False)
13
+
14
+ # Load the tokenizer from its local directory (replace with actual tokenizer type)
15
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
16
+
17
+ # Load the Peft model (assuming it's a custom class or adaptation)
18
+ AI_model = PeftModel.from_pretrained(base_model, peft_model_id)
19
+
20
+ def generate_translation(source_text, device="cpu"):
21
+ # Encode the source text
22
+ input_ids = tokenizer.encode(source_text, return_tensors='pt').to(device)
23
+
24
+ # Move the model to the same device as input_ids
25
+ model = base_model.to(device)
26
+
27
+ # Generate the translation with adjusted decoding parameters
28
+ generated_ids = model.generate(
29
+ input_ids=input_ids,
30
+ max_length=512, # Adjust max_length if needed
31
+ num_beams=4,
32
+ length_penalty=5, # Adjust length_penalty if needed
33
+ no_repeat_ngram_size=4,
34
+ early_stopping=True
35
+ )
36
+
37
+ # Decode the generated translation excluding special tokens
38
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
39
+
40
+ return generated_text
41
+
42
+ def translate(text):
43
+ return generate_translation(text)
44
+
45
+ # Define the Gradio interface
46
+ iface = gr.Interface(
47
+ fn=translate,
48
+ inputs="text",
49
+ outputs="text",
50
+ title="Translation App",
51
+ description="Translate text using a fine-tuned Helsinki model."
52
+ )
53
+
54
+ # Launch the Gradio app
55
+ iface.launch()