import streamlit as st
import torch
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the fine-tuned model and tokenizer once and cache them, so Streamlit
# does not reload the weights on every rerun of the script
@st.cache_resource
def load_model():
    model_name = "rohangbs/fine-tuned-gpt2"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    # Move the model to the GPU if one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return tokenizer, model.to(device), device

tokenizer, model, device = load_model()
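# Note: the first call to from_pretrained downloads the weights from the
# Hugging Face Hub; subsequent runs reuse the local cache.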

# Generate a response for a single user prompt
def chat_with_model(input_prompt, max_length=200):
    model.eval()
    # Format the input prompt with the special tokens used during fine-tuning
    prompt = f"<|startoftext|>[WP] {input_prompt}\n[RESPONSE]"
    # Tokenize and encode the prompt, and send it to the device
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    # Sample a response with top-k / top-p (nucleus) sampling
    sample_outputs = model.generate(
        generated,
        do_sample=True,
        top_k=50,
        max_length=max_length,
        top_p=0.95,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode the output and strip the [WP]/[RESPONSE] markers
    response_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", response_text)[1:]
    clean_responses = [response.strip() for response in wp_responses if response.strip()]
    # Return the first non-empty segment, with a fallback message
    return clean_responses[0] if clean_responses else "I couldn't generate a response."
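
# Example usage (hypothetical question; output varies because of sampling):
#   chat_with_model("What services does the company offer?")
#   -> returns a single cleaned response string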

# Streamlit UI
st.title("Chatbot For Company Details")
st.write("A GPT-2 model fine-tuned on a company dataset.")
# User input
prompt = st.text_area("Ask your question:", height=150)

if st.button("Send"):
    if prompt.strip():
        with st.spinner("Generating..."):
            # Generate and display the response
            response = chat_with_model(prompt)
        st.subheader("Generated Response:")
        st.write(response)
    else:
        st.warning("Please enter a prompt.")