"""Gradio demo that samples text continuations from a locally trained GPT model."""

import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer
import gradio as gr

from train_get2_8_init import GPT, GPTConfig

# Using the GPT-2 tokenizer for compatibility with the model's vocabulary.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model from the default config and load the trained weights.
config = GPTConfig()
model = GPT(config)
# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load trusted files; on newer torch versions
# weights_only=True restricts loading to plain tensors.
model.load_state_dict(
    torch.load("decoder_only_transformer.pth", map_location=torch.device(device))
)
model.eval()
model.to(device)

# Maximum number of positions the model can attend to, if the config exposes it.
# NOTE(review): assumes GPTConfig uses a nanoGPT-style `block_size` attribute;
# confirm against train_get2_8_init. If absent, no cropping is applied.
BLOCK_SIZE = getattr(config, "block_size", None)


def generate_text(input_text, max_length=50, top_k=50):
    """Sample a continuation of *input_text* using top-k sampling.

    Args:
        input_text: Prompt string. An empty prompt starts generation from the
            GPT-2 ``<|endoftext|>`` token instead of feeding the model an
            empty sequence. That start token is not included in the result.
        max_length: Target total length in tokens, prompt included. A prompt
            that is already this long is returned without new tokens.
        top_k: Number of most-likely next tokens to sample from at each step.
            It is clamped to the range [1, vocab size].

    Returns:
        The decoded prompt followed by the sampled continuation.
    """
    prompt_ids = tokenizer.encode(input_text)
    if prompt_ids:
        start = 0
    else:
        # Nothing to condition on: seed with GPT-2's end-of-text token and
        # strip it from the decoded output.
        prompt_ids = [tokenizer.eos_token_id]
        start = 1

    x = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        while x.size(1) < max_length:
            # Keep only the most recent tokens that fit in the model's context.
            x_cond = x if BLOCK_SIZE is None else x[:, -BLOCK_SIZE:]
            # Index 0 of the model output is treated as logits (B, T, vocab_size).
            # NOTE(review): presumably the model returns (logits, loss); confirm.
            logits = model(x_cond)[0]
            logits = logits[:, -1, :]  # logits for the next token only

            probs = F.softmax(logits, dim=-1)
            k = max(1, min(top_k, probs.size(-1)))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # multinomial accepts unnormalized weights, so no renormalization is needed.
            ix = torch.multinomial(topk_probs, 1)
            next_token = torch.gather(topk_indices, -1, ix)  # map back to vocab ids
            x = torch.cat((x, next_token), dim=1)

    return tokenizer.decode(x[0, start:].tolist())


# Gradio Interface
def gradio_interface(input_text):
    """Gradio callback: generate text with the default sampling settings."""
    return generate_text(input_text)


interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Textbox(lines=2, placeholder="Generated text will appear here..."),
    title="Text Prediction App",
    description="Enter a text prompt to generate the next sequence of words.",
)

if __name__ == "__main__":
    # Launch the app only when run as a script, not when imported.
    interface.launch()