# Provenance (page-header residue, kept as comments so the file parses):
# author: arsath-sm — commit "Update app.py" (3ab7e14, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
# Checkpoint identifier; both the tokenizer and the causal-LM weights are
# fetched from it once, at import time.
model_name = 'abinayam/gpt-2-tamil'
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
# System prompt
system_prompt = """You are an expert Tamil language model specializing in spelling and grammar correction. Your task is to:
1. Correct any spelling errors in the given text.
2. Fix grammatical mistakes, including proper application of sandhi rules.
3. Ensure the corrected text maintains the original meaning and context.
4. Provide the corrected version of the entire input text.
Remember to preserve the structure and intent of the original text while making necessary corrections."""
# Known misspellings mapped to their corrected spellings. preprocess_text()
# applies them as literal substring replacements, in insertion order.
common_errors = dict([
    ('பழங்கல்', 'பழங்கள்'),
    # Add more common spelling errors here
])
def apply_sandhi_rules(text):
    """Apply rule-based Tamil sandhi corrections to *text*.

    Current rule: when a word ending in the dative suffix ``கு`` / ``க்கு`` is
    followed by a word that starts with one of the hard consonants க, ச, த
    or ப, that same consonant is doubled onto the end of the first word, e.g.

        கடைக்கு போனேன்       -> கடைக்குப் போனேன்
        பள்ளிக்கு செல்கிறேன்  -> பள்ளிக்குச் செல்கிறேன்
        அவனுக்கு தெரியும்     -> அவனுக்குத் தெரியும்

    The whitespace between the two words is collapsed to a single space.
    Text that already carries the doubled consonant is left unchanged.

    Args:
        text: Tamil text to adjust.

    Returns:
        The text with the rule applied to every matching word pair.
    """
    # Group 2 captures the next word's initial consonant. It is echoed back
    # with a pulli (virama, ்) so the inserted consonant always agrees with it.
    # Inserting a fixed ப் would be wrong before words starting with க/ச/த.
    text = re.sub(r'(கு|க்கு)\s+(ப|த|க|ச)', r'\1\2் \2', text)
    # Add more sandhi rules as needed
    return text
def preprocess_text(text):
    """Replace every known misspelling listed in ``common_errors``.

    Each entry is applied as a plain (non-regex) substring replacement, one
    after another in the dictionary's insertion order.

    Args:
        text: Raw user input.

    Returns:
        The text with all dictionary corrections applied.
    """
    corrected = text
    for wrong in common_errors:
        corrected = corrected.replace(wrong, common_errors[wrong])
    return corrected
def postprocess_text(text):
    """Run the rule-based sandhi pass over the model's output.

    Args:
        text: Text produced by the language model.

    Returns:
        The result of ``apply_sandhi_rules(text)``.
    """
    return apply_sandhi_rules(text)
def correct_text(input_text):
    """Correct spelling and grammar in Tamil *input_text*.

    Pipeline:
      1. dictionary-based misspelling fixes (``preprocess_text``);
      2. a sampled continuation from the causal LM, prompted with
         ``system_prompt`` + ``Input: ...`` + ``Corrected:``;
      3. rule-based sandhi fixes (``postprocess_text``).

    Args:
        input_text: Raw text from the UI textbox.

    Returns:
        The model's first answer after the ``Corrected:`` marker, with the
        sandhi rules applied. Returns an empty string when the input is empty
        or whitespace-only.
    """
    # Nothing to correct: skip the (slow) model call entirely.
    if not input_text or not input_text.strip():
        return ""

    preprocessed_text = preprocess_text(input_text)
    full_prompt = f"{system_prompt}\n\nInput: {preprocessed_text}\n\nCorrected:"

    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, so generate() does not have to infer it.
    encoded = tokenizer(full_prompt, return_tensors='pt')
    prompt_length = encoded['input_ids'].shape[-1]
    # NOTE(review): a very long input can push prompt + 100 new tokens past the
    # model's context window — confirm the checkpoint's limit and truncate if needed.

    with torch.no_grad():
        output = model.generate(
            **encoded,
            # Budget for the correction itself, independent of prompt length
            # (equivalent to the previous max_length = prompt + 100).
            max_new_tokens=100,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            # GPT-2 style tokenizers usually define no pad token; reusing EOS
            # avoids the fallback warning. If eos_token_id is None this is the default.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. The prompt is never re-parsed, so
    # an extra "Corrected:" in the prompt or completion cannot shift the answer.
    completion = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # A plain LM may continue the prompt's pattern with a fresh "Input:"
    # section; keep only the first answer.
    corrected_text = completion.partition("\n\nInput:")[0].strip()

    return postprocess_text(corrected_text)
# Create the Gradio interface: one multi-line text box in, one text box out,
# wired to correct_text().
iface = gr.Interface(
    fn=correct_text,
    inputs=gr.Textbox(lines=5, placeholder="Enter Tamil text here..."),
    outputs=gr.Textbox(label="Corrected Text"),
    title="Tamil Spell Corrector and Grammar Checker",
    # Built from model_name so the UI text cannot drift from the loaded checkpoint.
    description=(
        f"This app uses the '{model_name}' model along with custom rules "
        "to correct spelling and grammar in Tamil text."
    ),
    # Sample inputs; the first contains பழங்கல், which common_errors corrects.
    examples=[
        ["நான் நேற்று கடைக்கு போனேன். அங்கே நிறைய பழங்கல் வாங்கினேன்."],
        ["நான் பள்ளிகு செல்கிறேன்."],
        ["அவன் வீட்டுகு வந்தான்."],
    ],
)
# Launch the app
iface.launch()