# Source: Hugging Face Space by elalimy — "Update app.py" (commit db678d9, verified), 1.94 kB
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hub identifiers for the fine-tuned PEFT adapter repository.
HUGGING_FACE_USER_NAME = "elalimy"
model_name = "my_awesome_peft_finetuned_helsinki_model"
peft_model_id = f"{HUGGING_FACE_USER_NAME}/{model_name}"
# Load the adapter config; it records which base model the adapter was trained on.
config = PeftConfig.from_pretrained(peft_model_id)
# Load the base seq2seq model named in the adapter config (full precision; 8-bit disabled).
base_model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=False)
# Tokenizer must match the base model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Wrap the base model with the fine-tuned PEFT adapter weights.
AI_model = PeftModel.from_pretrained(base_model, peft_model_id)
def generate_translation(source_text, device="cpu"):
    """Translate `source_text` with the PEFT fine-tuned seq2seq model.

    Args:
        source_text: Input text in the source language.
        device: Torch device string (e.g. "cpu" or "cuda"); defaults to "cpu".

    Returns:
        The decoded translation string, with special tokens removed.
    """
    # Encode the source text and move the ids to the target device.
    input_ids = tokenizer.encode(source_text, return_tensors='pt').to(device)
    # BUG FIX: generate with the PEFT-adapted model (AI_model), not the raw
    # base_model — the original code loaded the adapter but never used it,
    # so the fine-tuned weights were silently ignored.
    model = AI_model.to(device)
    # Beam-search decoding with repetition suppression.
    generated_ids = model.generate(
        input_ids=input_ids,
        max_length=512,         # cap on generated sequence length
        num_beams=4,
        length_penalty=5,       # favors longer outputs; tune if translations truncate
        no_repeat_ngram_size=4,
        early_stopping=True
    )
    # Decode the best beam, excluding special tokens (pad/eos/etc.).
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
def translate(text):
    """Gradio callback: translate `text` with the loaded model (CPU by default)."""
    translation = generate_translation(text)
    return translation
# Define the Gradio interface: a single text box in, translated text out.
iface = gr.Interface(
    fn=translate,
    inputs="text",
    outputs="text",
    title="Translation App",
    description="Translate text using a fine-tuned Helsinki model."
)
# Launch the Gradio app (blocking call; serves the web UI).
iface.launch()