# Streamlit chat UI around a GPT-2 model fine-tuned on a company Q&A dataset.
import re

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
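
# To launch the app (the filename app.py is only an assumption about how this
# script is saved):
#   streamlit run app.py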

# Load the fine-tuned model and tokenizer from the Hugging Face Hub.
model_name = "rohangbs/fine-tuned-gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
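
# Streamlit re-executes this script on every interaction, so the model above is
# reloaded on each rerun. A cached loader avoids that; a minimal sketch, assuming
# a Streamlit version that provides st.cache_resource:
#
#   @st.cache_resource
#   def load_model(name):
#       tok = GPT2Tokenizer.from_pretrained(name)
#       mdl = GPT2LMHeadModel.from_pretrained(name)
#       return tok, mdl.to("cuda" if torch.cuda.is_available() else "cpu")
#
#   tokenizer, model = load_model(model_name)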

def chat_with_model(input_prompt, max_length=200):
    """Generate a single response for the user's question."""
    model.eval()

    # Wrap the question in the prompt format used during fine-tuning.
    prompt = f"<|startoftext|>[WP] {input_prompt}\n[RESPONSE]"

    # Encode the prompt and add a batch dimension.
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

    # Sample one continuation with top-k and nucleus (top-p) filtering.
    sample_outputs = model.generate(
        generated,
        do_sample=True,
        top_k=50,
        max_length=max_length,
        top_p=0.95,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode, split on the [WP]/[RESPONSE] markers, and keep the first
    # non-empty chunk that follows them.
    response_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", response_text)[1:]
    clean_responses = [response.strip() for response in wp_responses if response.strip()]

    return clean_responses[0] if clean_responses else "I couldn't generate a response."
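
# Quick sanity check outside the UI (hypothetical question, run e.g. in a Python
# shell once the model has loaded):
#   chat_with_model("What services does the company offer?")
# returns the text after the first [RESPONSE] marker, or the fallback string.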

# Streamlit page layout and interaction.
st.title("Chatbot For Company Details")
st.write("A GPT-2 model fine-tuned on a company dataset.")

prompt = st.text_area("Ask your question:", height=150)

if st.button("Send"):
    if prompt.strip():
        with st.spinner("Generating..."):
            response = chat_with_model(prompt)
        st.subheader("Generated Response:")
        st.write(response)
    else:
        st.warning("Please enter a prompt.")