import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer
import gradio as gr

# Load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # GPT-2 tokenizer, for compatibility with the trained model
# Load model definition and trained weights
from train_get2_8_init import GPT, GPTConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model and load the checkpoint
config = GPTConfig()
model = GPT(config)
model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location=torch.device(device)))
model.to(device)
model.eval()
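# Optional sanity check: a minimal sketch, assuming GPTConfig exposes a
# `vocab_size` field (as in nanoGPT-style configs); skip or adapt if your
# config differs. It guards against a tokenizer/checkpoint vocab mismatch:
# assert config.vocab_size >= tokenizer.vocab_size, "tokenizer vocab exceeds model vocab"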
# Prediction function: autoregressive generation with top-k sampling
def generate_text(input_text, max_length=50, top_k=50):
    with torch.no_grad():
        # Tokenize the prompt into a (1, T) tensor of token ids
        tokens = tokenizer.encode(input_text, return_tensors="pt").to(device)
        x = tokens
        # Keep sampling until the sequence reaches max_length tokens (prompt included)
        while x.size(1) < max_length:
            # Forward pass to get logits: (B, T, vocab_size)
            logits = model(x)[0]
            # Keep only the logits at the last position
            logits = logits[:, -1, :]
            # Convert to probabilities and restrict sampling to the top-k candidates
            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            # Sample one token from the top-k distribution
            ix = torch.multinomial(topk_probs, 1)
            # Map the sampled position back to a vocabulary index
            xcol = torch.gather(topk_indices, -1, ix)
            # Append the new token to the sequence
            x = torch.cat((x, xcol), dim=1)
        # Decode the token ids back into text
        generated_text = tokenizer.decode(x[0])
        return generated_text
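# Quick standalone check of the sampler (no UI), using the checkpoint loaded
# above; the prompt is illustrative and output will vary run to run because
# sampling is stochastic:
# print(generate_text("Once upon a time", max_length=64, top_k=40))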
# Gradio interface
def gradio_interface(input_text):
    return generate_text(input_text)

interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Textbox(lines=2, placeholder="Generated text will appear here..."),
    title="Text Prediction App",
    description="Enter a text prompt to generate the next sequence of words.",
)

# Launch the app
interface.launch()
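# When hosted on a Space, launch() is picked up automatically; if you run this
# script locally and want a temporary public URL, Gradio supports e.g.:
# interface.launch(share=True)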