arsath-sm committed
Commit 3ab7e14 · verified · 1 Parent(s): e2b6396

Update app.py

Files changed (1)
  1. app.py +27 -4
app.py CHANGED
@@ -8,6 +8,15 @@ model_name = 'abinayam/gpt-2-tamil'
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
 
+# System prompt
+system_prompt = """You are an expert Tamil language model specializing in spelling and grammar correction. Your task is to:
+1. Correct any spelling errors in the given text.
+2. Fix grammatical mistakes, including proper application of sandhi rules.
+3. Ensure the corrected text maintains the original meaning and context.
+4. Provide the corrected version of the entire input text.
+
+Remember to preserve the structure and intent of the original text while making necessary corrections."""
+
 # Common error corrections
 common_errors = {
     'பழங்கல்': 'பழங்கள்',
@@ -35,15 +44,29 @@ def correct_text(input_text):
     # Preprocess the input text
     preprocessed_text = preprocess_text(input_text)
 
-    # Tokenize the preprocessed text
-    input_ids = tokenizer.encode(preprocessed_text, return_tensors='pt')
+    # Prepare the full prompt with system prompt and input text
+    full_prompt = f"{system_prompt}\n\nInput: {preprocessed_text}\n\nCorrected:"
+
+    # Tokenize the full prompt
+    input_ids = tokenizer.encode(full_prompt, return_tensors='pt')
 
     # Generate corrected text
     with torch.no_grad():
-        output = model.generate(input_ids, max_length=100, num_return_sequences=1, temperature=0.7)
+        output = model.generate(
+            input_ids,
+            max_length=len(input_ids[0]) + 100,  # Adjust based on expected output length
+            num_return_sequences=1,
+            temperature=0.7,
+            do_sample=True,
+            top_k=50,
+            top_p=0.95
+        )
 
     # Decode the generated text
-    corrected_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Extract the corrected text (everything after "Corrected:")
+    corrected_text = generated_text.split("Corrected:")[-1].strip()
 
     # Postprocess the corrected text
     final_text = postprocess_text(corrected_text)
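
For reference, a minimal end-to-end sketch of how the updated correct_text path behaves, reconstructed from the hunks above. This is not part of the commit: the system prompt is abbreviated here, the sample sentence is a hypothetical input, and the preprocess_text/postprocess_text helpers that app.py applies around this path are omitted for brevity.

# Minimal usage sketch (illustrative, not the committed file).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'abinayam/gpt-2-tamil'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Abbreviated stand-in; app.py uses the full system_prompt from the diff.
system_prompt = "You are an expert Tamil language model specializing in spelling and grammar correction."

def correct_text(input_text):
    # Same prompt format as the new code: instructions, then Input/Corrected markers.
    full_prompt = f"{system_prompt}\n\nInput: {input_text}\n\nCorrected:"
    input_ids = tokenizer.encode(full_prompt, return_tensors='pt')
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=len(input_ids[0]) + 100,  # leave room for the corrected sentence
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only what the model produced after the "Corrected:" marker.
    return generated_text.split("Corrected:")[-1].strip()

print(correct_text('நான் பழங்கல் வாங்கினேன்.'))  # hypothetical sample using a misspelling from common_errors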