import torch
from torch import nn
import torch.nn.functional as F
from tokenizers import Tokenizer
import gradio as gr
# Transformer-encoder language model
class RachanaLLM(nn.Module):
    def __init__(self, vocab_size=50000, embed_dim=284, num_heads=4, num_layers=4, dropout=0.2, max_len=256):
        super(RachanaLLM, self).__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encodings, one vector per position up to max_len
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, src_mask=None):
        # Scale embeddings by sqrt(d_model), then add positional encodings for the current length
        src = self.embedding(src) * torch.sqrt(torch.tensor(self.embed_dim, dtype=torch.float32))
        src = src + self.positional_encoding[:, :src.size(1)]
        output = self.transformer_encoder(src, src_key_padding_mask=src_mask)
        return self.output_layer(output)
# Load model and tokenizer
checkpoint = torch.load("best_model.pth", map_location=torch.device('cpu'))  # Map weights to CPU
model = RachanaLLM()
model.load_state_dict(checkpoint['model_state'] if 'model_state' in checkpoint else checkpoint)
model.eval()

tokenizer = Tokenizer.from_file("telugu_tokenizer_50k.json")  # Use your actual path
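
# Optional sanity check (commented out; assumes the files loaded above are present):
# round-trip a short Telugu sentence through the tokenizer to confirm it loaded correctly.
# ids = tokenizer.encode("ఈ రోజు వాతావరణం చాలా బాగుంది.").ids
# print(tokenizer.decode(ids))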
# Beam search with temperature, top-k sampling, and repetition penalty
def beam_search_decoder(model, tokenizer, input_ids, beam_width=3, max_length=20, temperature=1.0, top_k=5, repetition_penalty=1.2):
    model.eval()
    with torch.no_grad():
        # Each entry is [sequence tensor of shape (1, seq_len), cumulative log-probability]
        sequences = [[input_ids, 0.0]]
        for _ in range(max_length):
            all_candidates = []
            for seq, score in sequences:
                outputs = model(seq)
                logits = outputs[:, -1] / temperature
                # Penalize tokens that have already been generated: shrink positive logits,
                # and push negative logits further down
                for token_id in set(seq[0].tolist()):
                    if logits[0][token_id] > 0:
                        logits[0][token_id] /= repetition_penalty
                    else:
                        logits[0][token_id] *= repetition_penalty
                top_logits, top_indices = torch.topk(logits, top_k, dim=-1)
                softmax_scores = F.log_softmax(top_logits, dim=-1)
                for i in range(top_k):
                    next_token_id = top_indices[0][i].item()
                    # Skip tokens that already appear twice in the sequence to limit repetition
                    if seq[0].tolist().count(next_token_id) < 2:
                        candidate = [torch.cat([seq, top_indices[:, i:i+1]], dim=1), score + softmax_scores[0][i].item()]
                        all_candidates.append(candidate)
            if not all_candidates:  # Every top-k token was filtered out; stop early
                break
            # Keep the beam_width highest-scoring candidates
            ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
            sequences = ordered[:beam_width]
        best_seq = sequences[0][0]
        result_text = tokenizer.decode(best_seq[0].tolist())
        return result_text
# Predict function for the Gradio interface
def predict(input_sentence):
    input_ids = torch.tensor(tokenizer.encode(input_sentence).ids).unsqueeze(0)
    generated_text = beam_search_decoder(model, tokenizer, input_ids, temperature=0.8, top_k=10)
    return generated_text
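
# Example one-off generation without the Gradio UI (commented out; input taken from the
# examples below). Useful as a quick smoke test before launching the interface.
# print(predict("సినిమా బాగుందా లేదా చెప్పు!"))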
# Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Type a sentence..."),
    outputs="text",
    examples=[
        "ఈ రోజు వాతావరణం చాలా బాగుంది.",
        "సినిమా బాగుందా లేదా చెప్పు!",
        "ఆలయ అధికారులు దర్శన ఏర్పాట్లు చేశారు."
    ],
    title="Rachana LLM"
)

iface.launch(share=True)