import re

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the fine-tuned model and tokenizer
model_name = "rohangbs/fine-tuned-gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Ensure the model is on the correct device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


# Function to generate a response
def chat_with_model(input_prompt, max_length=200):
    model.eval()

    # Format the input prompt with the special tokens used during fine-tuning
    prompt = f"<|startoftext|>[WP] {input_prompt}\n[RESPONSE]"

    # Tokenize and encode the prompt, and send it to the device
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

    # Generate a response with nucleus (top-p) and top-k sampling
    sample_outputs = model.generate(
        generated,
        do_sample=True,
        top_k=50,
        max_length=max_length,
        top_p=0.95,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the response and strip the [WP]/[RESPONSE] scaffolding
    response_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", response_text)[1:]
    clean_responses = [response.strip() for response in wp_responses if response.strip()]

    # Return the first valid response
    return clean_responses[0] if clean_responses else "I couldn't generate a response."


# Streamlit UI
st.title("Chatbot For Company Details")
st.write("A GPT-2 model fine-tuned on the company dataset.")

# User input
prompt = st.text_area("Ask your question:", height=150)

if st.button("Send"):
    if prompt.strip():
        with st.spinner("Generating..."):
            # Generate and display the response
            response = chat_with_model(prompt)
            st.subheader("Generated Response:")
            st.write(response)
    else:
        st.warning("Please enter a prompt.")
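
# Usage note (assumes this script is saved as app.py; the filename is illustrative):
# launch the Streamlit app from the command line with
#   streamlit run app.py
# The model "rohangbs/fine-tuned-gpt2" is downloaded from the Hugging Face Hub
# on first run, so an internet connection is required the first time.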