# NextWordGPT / app.py
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer
import gradio as gr
from train_get2_8_init import GPT, GPTConfig

# Load the GPT-2 tokenizer; the custom model reuses GPT-2's vocabulary
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Initialize the model and load the trained weights
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = GPTConfig()
model = GPT(config)
model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location=torch.device(device)))
model.to(device)
model.eval()
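
# NOTE: GPTConfig() is assumed to default to the same architecture
# (e.g. n_layer, n_head, n_embd, block_size, vocab_size) that
# decoder_only_transformer.pth was trained with; a mismatch would make
# load_state_dict() fail with missing/unexpected keys.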

# Generate a continuation of the prompt with top-k sampling
def generate_text(input_text, max_length=50, top_k=50):
    with torch.no_grad():
        # Tokenize the prompt into a (1, T) tensor of token ids
        x = tokenizer.encode(input_text, return_tensors="pt").to(device)
        while x.size(1) < max_length:
            # Forward pass; keep only the logits at the last position
            logits = model(x)[0]       # (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size)
            # Top-k sampling: restrict sampling to the k most probable tokens
            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            ix = torch.multinomial(topk_probs, 1)      # sample within the top k
            xcol = torch.gather(topk_indices, -1, ix)  # map back to vocabulary ids
            x = torch.cat((x, xcol), dim=1)            # append the new token
        # Decode the full token sequence back into text
        return tokenizer.decode(x[0])
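
# Note: max_length counts the prompt tokens as well, so a prompt that is
# already max_length tokens or longer is returned unchanged. Example call
# (hypothetical prompt):
#   generate_text("Once upon a time", max_length=60, top_k=40)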

# Gradio interface
def gradio_interface(input_text):
    return generate_text(input_text)

interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Textbox(lines=2, placeholder="Generated text will appear here..."),
    title="Text Prediction App",
    description="Enter a text prompt to generate the next sequence of words.",
)

# Launch the app
interface.launch()
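
# On Hugging Face Spaces, launch() with no arguments serves the app;
# pass share=True only when running locally and a public link is needed.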