GPT2 / app.py
import streamlit as st
import torch
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the fine-tuned model and tokenizer
model_name = "rohangbs/fine-tuned-gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Ensure the model is on the correct device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
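# NOTE: Streamlit reruns this whole script on every interaction, so the model
# is reloaded each time. A common fix is to cache the loading step; a minimal
# sketch, assuming Streamlit >= 1.18 (where st.cache_resource is available):
#
#     @st.cache_resource
#     def load_model(name="rohangbs/fine-tuned-gpt2"):
#         tok = GPT2Tokenizer.from_pretrained(name)
#         mdl = GPT2LMHeadModel.from_pretrained(name)
#         return tok, mdl.to("cuda" if torch.cuda.is_available() else "cpu")
#
#     tokenizer, model = load_model()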
# Function to generate a response
def chat_with_model(input_prompt, max_length=200):
    model.eval()
    # Format the input prompt with the special tokens used during fine-tuning
    prompt = f"<|startoftext|>[WP] {input_prompt}\n[RESPONSE]"
    # Tokenize the prompt as a batch of one and move it to the device
    generated = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Generate a response with top-k and nucleus (top-p) sampling
    sample_outputs = model.generate(
        generated,
        do_sample=True,          # sample rather than decode greedily
        top_k=50,                # restrict sampling to the 50 most likely tokens
        max_length=max_length,   # total length budget, prompt tokens included
        top_p=0.95,              # nucleus sampling over 95% cumulative probability
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
    )
    # Decode the output, dropping special tokens such as <|startoftext|>
    response_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Split on the echoed "[WP] ...\n" prompt and any "[RESPONSE]" markers,
    # keeping only the text that follows them
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", response_text)[1:]
    clean_responses = [response.strip() for response in wp_responses if response.strip()]
    # Return the first non-empty response
    return clean_responses[0] if clean_responses else "I couldn't generate a response."
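# Worked example of the parsing above (illustrative text, not real model
# output): decoding might yield
#     "[WP] What does the company do?\n[RESPONSE] We build analytics tools."
# The re.split pattern consumes the "[WP] ...\n" echo and the "[RESPONSE]"
# marker, so after dropping the leading empty piece and stripping whitespace,
# chat_with_model returns "We build analytics tools."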
# Streamlit UI
st.title("Chatbot For Company Details")
st.write("A GPT-2 model fine-tuned for Company dataset.")
# User input
prompt = st.text_area("Ask your question:", height=150)
if st.button("Send"):
if prompt.strip():
with st.spinner("Generating..."):
# Generate and display the response
response = chat_with_model(prompt)
st.subheader("Generated Response:")
st.write(response)
else:
st.warning("Please enter a prompt.")