Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration | |
from datasets import load_dataset | |
import pandas as pd | |
import pdfplumber | |
# Load the RAG tokenizer, retriever and generator once, then reuse them.
@st.cache_resource
def _load_rag_components(model_name="facebook/rag-sequence-nq"):
    """Load and return ``(tokenizer, retriever, model)`` for the RAG checkpoint *model_name*.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` keeps the loaded objects across those
    reruns instead of reloading the checkpoint each time.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(model_name)
    # NOTE(review): ``use_dummy_dataset=True`` loads a dummy variant of the
    # retrieval dataset. Nothing in this app adds the uploaded PDF/CSV text to
    # that index, so passages are retrieved from it rather than from the
    # user's files -- confirm this is intended.
    rag_retriever = RagRetriever.from_pretrained(model_name, use_dummy_dataset=True)
    rag_model = RagSequenceForGeneration.from_pretrained(model_name, retriever=rag_retriever)
    return rag_tokenizer, rag_retriever, rag_model


tokenizer, retriever, model = _load_rag_components()
# Answer a question with the module-level RAG tokenizer and model.
def get_rag_embeddings(question, context):
    """Generate a text answer to *question*, passing *context* alongside it.

    Despite the name, no embeddings are returned: the result is the first
    decoded sequence produced by ``model.generate``.

    Args:
        question: The user's question.
        context: Extra text (the caller passes the uploaded PDF/CSV content),
            encoded as the second segment of a text pair with *question*.

    Returns:
        The generated answer as a string, with special tokens removed.

    NOTE(review): the pair is truncated to the tokenizer's maximum length
    (``truncation=True``), so long contexts are cut off. The retriever also
    still fetches passages from its own index; presumably the uploaded text
    only influences the query. Confirm this matches the intended behavior.
    """
    encoded = tokenizer(question, context, return_tensors="pt", truncation=True)
    token_ids = encoded['input_ids']
    mask = encoded['attention_mask']
    with torch.no_grad():
        generated_ids = model.generate(input_ids=token_ids, attention_mask=mask)
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return answers[0]
# Extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of *pdf_file* that has extractable text.

    Args:
        pdf_file: Anything ``pdfplumber.open`` accepts; the app passes a
            Streamlit file upload.

    Returns:
        The non-empty page texts joined with newlines, with leading and
        trailing whitespace stripped; ``""`` when no page yields any text.
    """
    with pdfplumber.open(pdf_file) as pdf:
        # Skip pages whose extract_text() result is falsy (empty / None).
        page_texts = [text for text in (page.extract_text() for page in pdf.pages) if text]
    # Collect first and join once: repeated ``str +=`` in a loop is quadratic.
    return "\n".join(page_texts).strip()
# Per-run inputs: text extracted from the uploaded PDF ("" until one with text
# is loaded) and the uploaded CSV as a DataFrame (None until one is loaded).
pdf_text, csv_data = "", None
# Streamlit app UI
st.title("RAG-Powered PDF & CSV Chatbot")

# CSV file upload: parse it into a DataFrame and show a preview.
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
if csv_file is not None:
    try:
        csv_data = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # A malformed or empty upload should give a readable message, not a traceback.
        csv_data = None
        st.error(f"Could not read the CSV file: {e}")
    else:
        st.write("CSV file loaded successfully!")
        st.write(csv_data)
# PDF file upload: extract its text and show a preview.
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
if pdf_file is not None:
    try:
        pdf_text = extract_text_from_pdf(pdf_file)
    except Exception as e:
        # UI boundary: parsing an arbitrary upload can fail with several
        # library-specific exception types; report it instead of crashing.
        pdf_text = ""
        st.error(f"Could not read the PDF file: {e}")
    else:
        if pdf_text:
            st.success("PDF loaded successfully!")
            st.text_area("Extracted Text from PDF", pdf_text, height=200)
        else:
            st.warning("No extractable text found in the PDF.")
# User input for chatbot
user_input = st.text_input("Ask a question related to the PDF or CSV:")

# Get response on button click
if st.button("Get Response"):
    if not pdf_text and csv_data is None:
        st.warning("Please upload a PDF or CSV file first.")
    elif not user_input.strip():
        # Generating from an empty question produces meaningless output.
        st.warning("Please enter a question first.")
    else:
        # Combine the PDF text and the CSV rendered as a text table. Empty
        # parts are skipped so the context never starts with a stray newline.
        context_parts = [pdf_text]
        if csv_data is not None:
            context_parts.append(csv_data.to_string())
        combined_context = "\n".join(part for part in context_parts if part)

        # Keep the try body to the call that can fail; display on success only.
        try:
            with st.spinner("Generating response..."):
                response = get_rag_embeddings(user_input, combined_context)
        except Exception as e:
            st.error(f"Error while processing the question: {e}")
        else:
            st.write("### Response:")
            st.write(response)