# Spaces:
# Sleeping
# Sleeping
import re

import pandas as pd
import pdfplumber
import streamlit as st
import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer
# Load the RAG model and tokenizer.
@st.cache_resource
def _load_rag_components():
    """Load the RAG tokenizer, retriever and sequence-generation model.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` makes the checkpoints load once per
    server process instead of on every click.

    The retriever is loaded from the ``facebook/rag-sequence-nq`` model repo,
    which holds the RAG configuration. ``facebook/wiki_dpr`` names the passage
    *dataset*, not a model checkpoint, so it is not a valid argument for
    ``RagRetriever.from_pretrained``. ``use_dummy_dataset=True`` requests the
    small dummy version of that passage index.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq", retriever=rag_retriever
    )
    return rag_tokenizer, rag_retriever, rag_model


tokenizer, retriever, model = _load_rag_components()
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    """Extract the text layer of every page of a PDF.

    Args:
        pdf_file: Anything ``pdfplumber.open`` accepts; in this app, the
            object returned by ``st.file_uploader``.

    Returns:
        The non-empty page texts joined with newlines and stripped of
        surrounding whitespace, or ``""`` when no page yields any text.
    """
    with pdfplumber.open(pdf_file) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
    # Pages without extractable text give a falsy result (None or ""), so they
    # are skipped; join() replaces building the string with repeated +=.
    return "\n".join(text for text in page_texts if text).strip()
# Streamlit app
st.title("RAG-Powered PDF & CSV Chatbot")

# CSV file upload; csv_data stays None unless a file parses successfully.
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
csv_data = None
if csv_file is not None:
    try:
        csv_data = pd.read_csv(csv_file)
    except ValueError as exc:
        # pandas' ParserError and EmptyDataError, and UnicodeDecodeError, all
        # derive from ValueError: report the bad file instead of crashing.
        st.error(f"Could not read the CSV file: {exc}")
    else:
        st.write("CSV file loaded successfully!")
        st.write(csv_data)
# PDF file upload; pdf_text stays "" unless text is extracted successfully.
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
pdf_text = ""
if pdf_file is not None:
    try:
        pdf_text = extract_text_from_pdf(pdf_file)
    except Exception as exc:  # UI boundary: report any parse failure to the user
        # NOTE(review): pdfplumber/pdfminer presumably raise assorted exception
        # types on corrupt or encrypted files; narrow this once they are known.
        st.error(f"Could not read the PDF file: {exc}")
    else:
        if pdf_text:
            st.success("PDF loaded successfully!")
            st.text_area("Extracted Text from PDF", pdf_text, height=200)
        else:
            st.warning("No extractable text found in the PDF.")
# User input for chatbot
user_input = st.text_input("Ask a question related to the PDF or CSV:")

# Words per chunk when splitting the uploaded context into passages; chosen to
# mirror the ~100-word Wikipedia passages DPR/RAG were trained on (tunable).
_PASSAGE_WORDS = 100
# Title placed in front of every uploaded passage in RAG's
# "title / passage // question" generator input.
_PASSAGE_TITLE = "Uploaded document"
# Score for empty filler document slots; after the softmax over documents it
# gives them effectively zero weight.
_FILLER_DOC_SCORE = -1e4
# Spare tokens: BPE merges at the string joins can shift the token count
# slightly compared with tokenizing the pieces separately.
_TOKEN_SLACK = 4


def _select_passages(question, context, k):
    """Return up to *k* word chunks of *context*, most relevant to *question* first.

    Relevance is the number of distinct lowercased question words that also
    occur in the chunk. This lexical ranking stands in for the RAG retriever,
    whose index holds (dummy) Wikipedia passages rather than the uploaded
    files. Ties keep document order because ``sorted`` is stable.
    """
    words = context.split()
    passages = [
        " ".join(words[start:start + _PASSAGE_WORDS])
        for start in range(0, len(words), _PASSAGE_WORDS)
    ]
    query_terms = set(re.findall(r"\w+", question.lower()))

    def relevance(passage):
        return len(query_terms.intersection(re.findall(r"\w+", passage.lower())))

    return sorted(passages, key=relevance, reverse=True)[:k]


def _generate_answer(question, passages):
    """Generate an answer to *question* conditioned only on *passages*.

    Each generator input is ``title + title_sep + passage + doc_sep + question``.
    Passing these as ``context_input_ids``/``context_attention_mask``/
    ``doc_scores`` makes ``model.generate`` use them instead of querying the
    retriever. Exactly ``model.config.n_docs`` documents are supplied; unused
    slots are empty filler with a very low score.
    NOTE(review): RagSequenceForGeneration's candidate re-scoring appears to
    assume ``config.n_docs`` documents; verify on transformers upgrades.
    """
    cfg = model.config
    gen_tokenizer = tokenizer.generator
    n_docs = cfg.n_docs
    prefix = _PASSAGE_TITLE + cfg.title_sep
    suffix = cfg.doc_sep + question
    # Token room left for passage text once title, question and special tokens fit.
    reserved = len(gen_tokenizer(prefix + suffix)["input_ids"])
    budget = max(cfg.max_combined_length - reserved - _TOKEN_SLACK, 0)

    doc_inputs, doc_scores = [], []
    for passage in passages[:n_docs]:
        # Truncate the passage, not the combined string, so the question at
        # the end of the generator input is never cut off.
        passage_ids = gen_tokenizer(passage, add_special_tokens=False)["input_ids"][:budget]
        doc_inputs.append(prefix + gen_tokenizer.decode(passage_ids) + suffix)
        doc_scores.append(0.0)
    while len(doc_inputs) < n_docs:
        doc_inputs.append(prefix + suffix)
        doc_scores.append(_FILLER_DOC_SCORE)

    encoded = gen_tokenizer(
        doc_inputs,
        max_length=cfg.max_combined_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output = model.generate(
            context_input_ids=encoded["input_ids"],
            context_attention_mask=encoded["attention_mask"],
            doc_scores=torch.tensor([doc_scores]),
        )
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]


# Get response on button click
if st.button("Get Response"):
    if not pdf_text and csv_data is None:
        st.warning("Please upload a PDF or CSV file first.")
    elif not user_input.strip():
        st.warning("Please enter a question.")
    else:
        combined_context = pdf_text
        if csv_data is not None:
            combined_context += "\n" + csv_data.to_string()
        # Generate response using RAG, grounded in the uploaded content
        # (previously combined_context was built but never reached the model).
        passages = _select_passages(user_input, combined_context, model.config.n_docs)
        with st.spinner("Generating response..."):
            response = _generate_answer(user_input, passages)
        st.write("### Response:")
        st.write(response)