# rachana_llm / app.py
# Author: vineelnani16 · commit 280355d ("Update app.py") · 3.9 kB
import torch
from torch import nn
from tokenizers import Tokenizer
import torch.nn.functional as F
import gradio as gr
# Transformer-encoder language model built from PyTorch's stock multi-head
# attention layers (no custom attention), with a learned positional table.
class RachanaLLM(nn.Module):
    """Token embedding + learned positions + ``nn.TransformerEncoder`` + vocab projection.

    The attribute names ``embedding``, ``positional_encoding``,
    ``transformer_encoder`` and ``output_layer`` are the ``state_dict`` keys a
    saved checkpoint is loaded against, so they must not be renamed.
    """

    def __init__(self, vocab_size=50000, embed_dim=284, num_heads=4, num_layers=4, dropout=0.2, max_len=256):
        """Build the layers.

        Args:
            vocab_size: number of token ids covered by the embedding and output layer.
            embed_dim: model width (PyTorch requires it to be divisible by ``num_heads``).
            num_heads: attention heads per encoder layer.
            num_layers: number of stacked encoder layers.
            dropout: dropout probability inside each encoder layer.
            max_len: number of rows in the positional table, i.e. the longest
                sequence ``forward`` accepts.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embedding, zero-initialised, one row per position.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, src_mask=None):
        """Return next-token logits for every position of ``src``.

        Args:
            src: LongTensor of token ids shaped ``(batch, seq_len)`` (the
                encoder is built with ``batch_first=True``).
            src_mask: optional ``(batch, seq_len)`` key-padding mask where
                ``True`` marks positions to ignore. Despite its name it is
                forwarded as ``src_key_padding_mask``, not as an attention mask.

        Returns:
            Logits shaped ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the positional table length.
        """
        seq_len = src.size(1)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently truncates and the
            # addition fails with an opaque broadcasting error.
            raise ValueError(f"sequence length {seq_len} exceeds max_len={max_len}")
        # Scale embeddings by sqrt(d_model); a Python float avoids allocating a
        # tensor on every call and works on any device.
        x = self.embedding(src) * (self.embed_dim ** 0.5)
        x = x + self.positional_encoding[:, :seq_len]
        # NOTE(review): no causal mask is applied, so every position also
        # attends to later positions. Whether that matches how the checkpoint
        # was trained cannot be determined here -- confirm before adding one.
        x = self.transformer_encoder(x, src_key_padding_mask=src_mask)
        return self.output_layer(x)
# Load the trained weights onto the CPU, then the matching tokenizer.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load("best_model.pth", map_location=torch.device("cpu"))
model = RachanaLLM()
# The checkpoint is either a dict holding the weights under "model_state"
# or the state dict itself.
if "model_state" in checkpoint:
    state_dict = checkpoint["model_state"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()
tokenizer = Tokenizer.from_file("telugu_tokenizer_50k.json")
# Deterministic beam search with temperature scaling, top-k candidate
# truncation and a repetition penalty (no random sampling is performed).


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Penalise ``token_ids`` in row 0 of ``logits`` in place, CTRL-style.

    Positive logits are divided by ``penalty`` and negative ones multiplied by
    it, so a penalty > 1 always makes the token less likely. Plain division
    would instead raise negative logits toward zero and boost the token.
    """
    if token_ids.numel() == 0:
        return
    selected = logits[0, token_ids]
    logits[0, token_ids] = torch.where(selected < 0, selected * penalty, selected / penalty)


def beam_search_decoder(model, tokenizer, input_ids, beam_width=3, max_length=20, temperature=1.0, top_k=5, repetition_penalty=1.2):
    """Extend ``input_ids`` by up to ``max_length`` tokens and decode the best beam.

    At each step, every live beam is run through ``model``. The last
    position's logits are divided by ``temperature``, and tokens already in
    the beam are penalised. The ``top_k`` highest logits are then expanded.
    A beam's score is the sum of log-probabilities normalised over those
    ``top_k`` tokens only. A token that already occurs twice in a beam is
    never appended to it again. If every expansion is banned, the search
    stops early.

    Args:
        model: callable with an ``eval()`` method, assumed to map a
            ``(1, seq_len)`` LongTensor to logits ``(1, seq_len, vocab_size)``.
            If it has a ``positional_encoding`` tensor (as ``RachanaLLM``
            does), only the last ``positional_encoding.size(1)`` tokens are
            fed to it.
        tokenizer: object whose ``decode`` accepts a list of token ids.
        input_ids: LongTensor of prompt ids, shape ``(1, prompt_len)``.
        beam_width: number of beams kept after each step.
        max_length: maximum number of tokens appended.
        temperature: positive divisor applied to the logits.
        top_k: candidate tokens considered per beam (clamped to the vocab size).
        repetition_penalty: factor > 1 that discourages already-seen tokens.

    Returns:
        ``tokenizer.decode`` of the highest-scoring sequence, prompt included.

    Raises:
        ValueError: if ``temperature`` is not positive or ``input_ids`` is empty.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if input_ids.numel() == 0:
        raise ValueError("input_ids must contain at least one token")
    model.eval()
    # The learned positional table has a fixed length; cropping the context to
    # it keeps long prompts from overflowing the model.
    pos_table = getattr(model, "positional_encoding", None)
    max_context = pos_table.size(1) if isinstance(pos_table, torch.Tensor) else None
    with torch.no_grad():
        sequences = [[input_ids, 0.0]]
        for _ in range(max_length):
            all_candidates = []
            for seq, score in sequences:
                context = seq if max_context is None else seq[:, -max_context:]
                logits = model(context)[:, -1] / temperature
                _apply_repetition_penalty(logits, seq[0].unique(), repetition_penalty)
                k = min(top_k, logits.size(-1))
                top_logits, top_indices = torch.topk(logits, k, dim=-1)
                log_probs = F.log_softmax(top_logits, dim=-1)
                history = seq[0].tolist()
                for i in range(k):
                    next_token_id = top_indices[0, i].item()
                    # Hard cap: a token may appear at most twice in a beam.
                    if history.count(next_token_id) < 2:
                        extended = torch.cat([seq, top_indices[:, i:i + 1]], dim=1)
                        all_candidates.append([extended, score + log_probs[0, i].item()])
            if not all_candidates:
                # Every expansion was banned; keep the current beams instead of
                # crashing on an empty beam list.
                break
            # Stable sort: ties keep beam order, then top-k order.
            all_candidates.sort(key=lambda cand: cand[1], reverse=True)
            sequences = all_candidates[:beam_width]
    best_seq = sequences[0][0]
    return tokenizer.decode(best_seq[0].tolist())
# Predict function for the Gradio interface.
def predict(input_sentence):
    """Generate a continuation of ``input_sentence`` with beam search.

    Returns an empty string when the text encodes to no tokens (e.g. an empty
    box). In that case ``torch.tensor([])`` would be a float tensor and the
    model would have nothing to condition on.
    """
    ids = tokenizer.encode(input_sentence or "").ids
    if not ids:
        return ""
    input_ids = torch.tensor([ids], dtype=torch.long)  # shape (1, num_tokens)
    return beam_search_decoder(model, tokenizer, input_ids, temperature=0.8, top_k=10)
# Gradio UI: a two-line text box goes in, the generated continuation comes out.
EXAMPLE_PROMPTS = [
    "ఈ రోజు వాతావరణం చాలా బాగుంది.",
    "సినిమా బాగుందా లేదా చెప్పు!",
    "ఆలయ అధికారులు దర్శన ఏర్పాట్లు చేశారు.",
]
prompt_box = gr.Textbox(lines=2, placeholder="Type a sentence...")
iface = gr.Interface(
    fn=predict,
    inputs=prompt_box,
    outputs="text",
    examples=EXAMPLE_PROMPTS,
    title="Rachana LLM",
)
# share=True additionally requests a public Gradio share link.
iface.launch(share=True)