vineelnani16 committed on
Commit
2656ff7
1 Parent(s): 28e8d69

This file contains code to load the Rachana Telugu LLM and serve it through a Gradio interface.

Files changed (1)
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ from torch import nn
+ from tokenizers import Tokenizer
+ import torch.nn.functional as F
+ import gradio as gr
+
+ # Transformer encoder language model
+ class RachanaLLM(nn.Module):
+     def __init__(self, vocab_size=50000, embed_dim=284, num_heads=4, num_layers=4, dropout=0.2, max_len=256):
+         super(RachanaLLM, self).__init__()
+         self.embed_dim = embed_dim
+         self.embedding = nn.Embedding(vocab_size, embed_dim)
+         # Learned positional encoding for up to max_len positions
+         self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
+         encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+         self.output_layer = nn.Linear(embed_dim, vocab_size)
+
+     def forward(self, src, src_key_padding_mask=None):
+         # Scale embeddings by sqrt(embed_dim), as in the original Transformer
+         src = self.embedding(src) * torch.sqrt(torch.tensor(self.embed_dim, dtype=torch.float32))
+         src = src + self.positional_encoding[:, :src.size(1)]
+         output = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
+         return self.output_layer(output)
+
+ # Load model and tokenizer
+ checkpoint = torch.load("best_model.pth", map_location="cpu")  # Use your actual path; map_location lets it load without a GPU
+ model = RachanaLLM()
+ model.load_state_dict(checkpoint.get('model_state', checkpoint))
+ model.eval()
+ tokenizer = Tokenizer.from_file("telugu_tokenizer_50k.json")  # Use your actual path
+
+ # Beam search with temperature scaling, top-k candidate selection, and a repetition penalty
+ def beam_search_decoder(model, tokenizer, input_ids, beam_width=3, max_length=20, temperature=1.0, top_k=5, repetition_penalty=1.2):
+     model.eval()
+     with torch.no_grad():
+         # Each beam entry is [sequence tensor of shape (1, len), cumulative log-probability]
+         sequences = [[input_ids, 0.0]]
+
+         for _ in range(max_length):
+             all_candidates = []
+             for seq, score in sequences:
+                 outputs = model(seq)
+                 logits = outputs[:, -1] / temperature
+
+                 # Penalize every token already in the sequence; dividing a positive
+                 # logit and multiplying a negative one both reduce its probability
+                 for token_id in set(seq[0].tolist()):
+                     if logits[0, token_id] > 0:
+                         logits[0, token_id] /= repetition_penalty
+                     else:
+                         logits[0, token_id] *= repetition_penalty
+
+                 top_logits, top_indices = torch.topk(logits, top_k, dim=-1)
+                 softmax_scores = F.log_softmax(top_logits, dim=-1)
+
+                 for i in range(top_k):
+                     next_token_id = top_indices[0][i].item()
+                     if seq[0].tolist().count(next_token_id) < 2:  # Cap any token at two occurrences per sequence
+                         candidate = [torch.cat([seq, top_indices[:, i:i+1]], dim=1), score + softmax_scores[0][i].item()]
+                         all_candidates.append(candidate)
+
+             # Keep the beam_width highest-scoring candidates
+             ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+             sequences = ordered[:beam_width]
+
+         best_seq = sequences[0][0]
+         result_text = tokenizer.decode(best_seq[0].tolist())
+         return result_text
+
+
+ # Prediction function for the Gradio interface
+ def predict(input_sentence):
+     input_ids = torch.tensor(tokenizer.encode(input_sentence).ids).unsqueeze(0)
+     generated_text = beam_search_decoder(model, tokenizer, input_ids, temperature=0.8, top_k=10)
+     return generated_text
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Textbox(lines=2, placeholder="Type a sentence..."),
+     outputs="text",
+     examples=[
+         "ఈ రోజు వాతావరణం చాలా బాగుంది.",  # "The weather is very nice today."
+         "సినిమా బాగుందా లేదా చెప్పు!",  # "Tell me whether the movie is good or not!"
+         "ఆలయ అధికారులు దర్శన ఏర్పాట్లు చేశారు."  # "Temple officials arranged the darshan."
+     ],
+     title="Rachana LLM"
+ )
+
+ iface.launch(share=True)
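
For reference, the sign-aware repetition penalty in beam_search_decoder follows the common CTRL-style convention: a previously generated token is always made less likely, whether its logit is positive or negative. A minimal standalone sketch of that convention (the function name apply_repetition_penalty and the example values are illustrative, not part of app.py):

    import torch

    def apply_repetition_penalty(logits, seen_ids, penalty=1.2):
        # Dividing a positive logit or multiplying a negative one both
        # push the token's probability down
        for token_id in seen_ids:
            if logits[token_id] > 0:
                logits[token_id] /= penalty
            else:
                logits[token_id] *= penalty
        return logits

    logits = torch.tensor([2.0, -1.0, 0.5])
    print(apply_repetition_penalty(logits, {0, 1}))
    # tensor([ 1.6667, -1.2000,  0.5000]) -- both seen tokens become less likely

Dividing unconditionally, as some implementations do, would instead boost a repeated token whose logit is negative, which is why the branch on the sign matters. To smoke-test the app's generation path without the UI, calling predict("ఈ రోజు వాతావరణం చాలా బాగుంది.") after the definitions above (before iface.launch) should print one generated continuation.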