Spaces:
Sleeping
Sleeping
vineelnani16
committed on
Commit
•
2656ff7
1
Parent(s):
28e8d69
app.py
Browse filesThis file contains code to
app.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from tokenizers import Tokenizer
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# Enhanced Model with Custom Attention
class RachanaLLM(nn.Module):
    """Transformer-encoder language model over a fixed vocabulary.

    Embeds token ids, adds a learned positional encoding, runs them through
    a stack of TransformerEncoder layers, and projects the hidden states
    back to vocabulary logits.
    """

    def __init__(self, vocab_size=50000, embed_dim=284, num_heads=4, num_layers=4, dropout=0.2, max_len=256):
        super(RachanaLLM, self).__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned (not sinusoidal) positional encoding; one slot per position.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, src_mask=None):
        """Return logits of shape (batch, seq_len, vocab_size).

        ``src`` is a (batch, seq_len) tensor of token ids; ``src_mask``,
        if given, is forwarded as the encoder's key-padding mask.
        """
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        scale = torch.sqrt(torch.tensor(self.embed_dim, dtype=torch.float32))
        hidden = self.embedding(src) * scale
        hidden = hidden + self.positional_encoding[:, :hidden.size(1)]
        encoded = self.transformer_encoder(hidden, src_key_padding_mask=src_mask)
        return self.output_layer(encoded)
|
23 |
+
|
24 |
+
# Load model and tokenizer
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host
# (without it, torch.load raises when CUDA is unavailable).
checkpoint = torch.load("best_model.pth", map_location="cpu")  # Use your actual path
model = RachanaLLM()
# Checkpoint may be either a raw state_dict or a dict wrapping it under 'model_state'.
model.load_state_dict(checkpoint['model_state'] if 'model_state' in checkpoint else checkpoint)
model.eval()  # inference only: disables dropout
tokenizer = Tokenizer.from_file("telugu_tokenizer_50k.json")  # Use your actual path
|
30 |
+
|
31 |
+
# Advanced Beam Search with Temperature, Top-K Sampling, and Dynamic Repetition Penalty
def beam_search_decoder(model, tokenizer, input_ids, beam_width=3, max_length=20, temperature=1.0, top_k=5, repetition_penalty=1.2):
    """Beam-search decode up to ``max_length`` new tokens from ``input_ids``.

    Args:
        model: module mapping (1, seq) token ids to (1, seq, vocab) logits.
        tokenizer: object providing ``decode`` over a list of token ids.
        input_ids: (1, seq_len) LongTensor prompt.
        beam_width: number of beams kept after each step.
        max_length: number of decoding steps (new tokens).
        temperature: logit temperature (<1 sharpens, >1 flattens).
        top_k: candidate tokens considered per beam per step.
        repetition_penalty: >1 discourages re-emitting tokens already in a beam.

    Returns:
        Decoded text of the highest-scoring beam.
    """
    model.eval()
    with torch.no_grad():
        sequences = [[input_ids, 0.0]]

        for _ in range(max_length):
            all_candidates = []
            for seq, score in sequences:
                outputs = model(seq)
                logits = outputs[:, -1] / temperature

                # Penalize every token already present in this beam.  Dividing a
                # NEGATIVE logit would *raise* its probability, so (following the
                # CTRL paper) multiply negatives and divide positives.
                for token_id in set(seq[0].tolist()):
                    if logits[0][token_id] < 0:
                        logits[0][token_id] *= repetition_penalty
                    else:
                        logits[0][token_id] /= repetition_penalty

                top_logits, top_indices = torch.topk(logits, top_k, dim=-1)
                softmax_scores = F.log_softmax(top_logits, dim=-1)

                for i in range(top_k):
                    next_token_id = top_indices[0][i].item()
                    # Hard cap: never emit the same token a third time.
                    if seq[0].tolist().count(next_token_id) < 2:
                        candidate = [torch.cat([seq, top_indices[:, i:i+1]], dim=1),
                                     score + softmax_scores[0][i].item()]
                        all_candidates.append(candidate)

            # Every candidate was filtered by the repeat cap: stop early instead
            # of letting ``sequences`` become empty and crashing below.
            if not all_candidates:
                break

            # Keep the beam_width best-scoring candidates.
            ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
            sequences = ordered[:beam_width]

        best_seq = sequences[0][0]
        return tokenizer.decode(best_seq[0].tolist())
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
# Predict function for Gradio Interface
def predict(input_sentence):
    """Gradio handler: beam-search-generate text continuing ``input_sentence``."""
    token_ids = tokenizer.encode(input_sentence).ids
    batch = torch.tensor(token_ids).unsqueeze(0)  # (1, seq_len)
    return beam_search_decoder(model, tokenizer, batch, temperature=0.8, top_k=10)
|
72 |
+
|
73 |
+
# Gradio Interface
demo_examples = [
    "ఈ రోజు వాతావరణం చాలా బాగుంది.",
    "సినిమా బాగుందా లేదా చెప్పు!",
    "ఆలయ అధికారులు దర్శన ఏర్పాట్లు చేశారు."
]

iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Type a sentence..."),  # Updated line
    outputs="text",
    examples=demo_examples,
    title="Rachana LLM"
)

iface.launch(share=True)
|