"""Gradio demo that continues Telugu text with the RachanaLLM transformer.

Weights are read from ``best_model.pth`` and the tokenizer from
``telugu_tokenizer_50k.json``; continuations are produced by a deterministic
beam search (see ``beam_search_decoder``).
"""

import math
from collections import Counter

import gradio as gr
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from torch import nn


# Enhanced Model with Custom Attention
class RachanaLLM(nn.Module):
    """Token embedding + learned positions + Transformer encoder + vocabulary projection.

    Args:
        vocab_size: Vocabulary size (rows of the embedding, width of the output layer).
        embed_dim: Model width ``d_model``; must be divisible by ``num_heads``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked ``nn.TransformerEncoderLayer`` modules.
        dropout: Dropout probability inside each encoder layer.
        max_len: Longest sequence the learned positional encoding covers.
    """

    def __init__(self, vocab_size=50000, embed_dim=284, num_heads=4, num_layers=4,
                 dropout=0.2, max_len=256):
        super().__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encoding, zero-initialised; shape (1, max_len, embed_dim).
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, src_mask=None):
        """Return per-position vocabulary logits for a batch of token ids.

        Args:
            src: Integer token ids of shape (batch, seq_len) (``batch_first=True``).
            src_mask: Optional padding mask, forwarded as ``src_key_padding_mask``.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-encoding length.
        """
        seq_len = src.size(1)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            # Without this check the positional slice below is shorter than the
            # input and the addition fails with an opaque broadcasting error.
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        # Scale embeddings by sqrt(d_model) before adding positions.
        x = self.embedding(src) * math.sqrt(self.embed_dim)
        x = x + self.positional_encoding[:, :seq_len]
        # NOTE(review): no causal attention mask is applied here; confirm this
        # matches how the checkpoint was trained.
        output = self.transformer_encoder(x, src_key_padding_mask=src_mask)
        return self.output_layer(output)


# Load model and tokenizer.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
# torch.load unpickles its input: only load checkpoints from a trusted source.
checkpoint = torch.load("best_model.pth", map_location="cpu")  # Use your actual path
model = RachanaLLM()
state_dict = checkpoint['model_state'] if 'model_state' in checkpoint else checkpoint
model.load_state_dict(state_dict)
model.eval()

tokenizer = Tokenizer.from_file("telugu_tokenizer_50k.json")  # Use your actual path


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return a copy of ``logits`` (shape (1, vocab)) with already-seen tokens made less likely.

    Positive logits are divided by ``penalty`` and negative ones multiplied by it,
    so the penalty always lowers the score (plain division would *raise* a
    negative logit and favour the repeated token).
    """
    if penalty == 1.0:
        return logits
    seen = torch.unique(token_ids)
    logits = logits.clone()
    picked = logits[0, seen]
    logits[0, seen] = torch.where(picked < 0, picked * penalty, picked / penalty)
    return logits


# Beam search with temperature, top-k candidate expansion and a repetition penalty.
def beam_search_decoder(model, tokenizer, input_ids, beam_width=3, max_length=20,
                        temperature=1.0, top_k=5, repetition_penalty=1.2,
                        max_token_repeats=2):
    """Extend ``input_ids`` by up to ``max_length`` tokens with beam search and decode the best beam.

    No sampling happens. At each step every beam is expanded with its ``top_k``
    most likely next tokens, and the ``beam_width`` candidates with the highest
    cumulative log-probability are kept.

    Args:
        model: Callable mapping (1, seq_len) token ids to (1, seq_len, vocab) logits.
        tokenizer: Object whose ``decode`` turns a list of ids back into text.
        input_ids: Prompt token ids of shape (1, seq_len), seq_len >= 1.
        beam_width: Number of beams retained after each step (>= 1).
        max_length: Maximum number of tokens to append.
        temperature: Positive divisor applied to logits before normalisation.
        top_k: Candidate tokens considered per beam per step (clamped to vocab size).
        repetition_penalty: Penalty factor for tokens already present in a beam.
        max_token_repeats: A token already occurring this many times in a beam
            is never appended to it again.

    Returns:
        Decoded text of the highest-scoring sequence, prompt included.

    Raises:
        ValueError: On an empty prompt, non-positive ``temperature`` or ``beam_width < 1``.
    """
    if input_ids.dim() != 2 or input_ids.size(1) == 0:
        raise ValueError("input_ids must have shape (1, seq_len) with seq_len >= 1")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    # Feed the model at most as many tokens as its positional encoding covers.
    positions = getattr(model, "positional_encoding", None)
    context_len = positions.size(1) if positions is not None else None
    device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        beams = [(input_ids.to(device), 0.0)]
        for _ in range(max_length):
            candidates = []
            for seq, score in beams:
                context = seq if context_len is None else seq[:, -context_len:]
                logits = model(context)[:, -1] / temperature
                logits = _apply_repetition_penalty(logits, seq[0], repetition_penalty)
                # Normalise over the full vocabulary so cumulative scores are
                # comparable across beams (top-k selection itself is unchanged).
                log_probs = F.log_softmax(logits, dim=-1)
                k = min(top_k, log_probs.size(-1))
                top_log_probs, top_ids = torch.topk(log_probs, k, dim=-1)
                counts = Counter(seq[0].tolist())
                for i in range(k):
                    if counts[top_ids[0, i].item()] >= max_token_repeats:
                        continue
                    extended = torch.cat([seq, top_ids[:, i:i + 1]], dim=1)
                    candidates.append((extended, score + top_log_probs[0, i].item()))
            if not candidates:
                # Every expansion was blocked by the repeat limit; keep current beams.
                break
            candidates.sort(key=lambda cand: cand[1], reverse=True)
            beams = candidates[:beam_width]

    best_seq = beams[0][0]
    return tokenizer.decode(best_seq[0].tolist())


# Predict function for Gradio Interface
def predict(input_sentence):
    """Return the beam-search continuation of ``input_sentence``; empty input yields ""."""
    if not input_sentence:
        return ""
    ids = tokenizer.encode(input_sentence).ids
    if not ids:
        return ""
    input_ids = torch.tensor(ids).unsqueeze(0)
    return beam_search_decoder(model, tokenizer, input_ids, temperature=0.8, top_k=10)


# Gradio Interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Type a sentence..."),
    outputs="text",
    examples=[
        "ఈ రోజు వాతావరణం చాలా బాగుంది.",
        "సినిమా బాగుందా లేదా చెప్పు!",
        "ఆలయ అధికారులు దర్శన ఏర్పాట్లు చేశారు.",
    ],
    title="Rachana LLM",
)

if __name__ == "__main__":
    iface.launch(share=True)