# pdfcsvdatarag / app.py
# Shankarm08's picture
# Update app.py
# 93a3da9 verified
# raw
# history blame
# 2.65 kB
import streamlit as st
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset
import pandas as pd
import pdfplumber
# Load RAG model, tokenizer, and retriever.
# Checkpoint identifier shared by all three components.
RAG_MODEL_NAME = "facebook/rag-sequence-nq"


@st.cache_resource
def _load_rag_components():
    """Load the RAG tokenizer, retriever and generator a single time.

    Streamlit re-executes this script from top to bottom on every widget
    interaction; without caching, each rerun would call ``from_pretrained``
    three times again. ``st.cache_resource`` keeps the returned objects alive
    and hands the same instances back on later reruns.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple. The retriever is created
        with ``use_dummy_dataset=True`` and is attached to the model.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(RAG_MODEL_NAME)
    rag_retriever = RagRetriever.from_pretrained(
        RAG_MODEL_NAME, use_dummy_dataset=True
    )
    rag_model = RagSequenceForGeneration.from_pretrained(
        RAG_MODEL_NAME, retriever=rag_retriever
    )
    return rag_tokenizer, rag_retriever, rag_model


# Same module-level names as before so the rest of the script is unaffected.
tokenizer, retriever, model = _load_rag_components()
# Generate an answer with the RAG model
def get_rag_embeddings(question, context):
    """Return the RAG model's generated text for *question* and *context*.

    Despite the name, this returns a decoded string, not embeddings.

    The question and context are tokenized together as a text pair (with
    truncation enabled), passed to ``model.generate`` with gradient tracking
    disabled, and the first generated sequence is decoded with special
    tokens removed.

    Args:
        question: The user's question.
        context: Supporting text tokenized alongside the question.

    Returns:
        The first decoded sequence produced by ``model.generate``.
    """
    encoded = tokenizer(question, context, return_tensors="pt", truncation=True)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
# Pull the text layer out of a PDF
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in *pdf_file*.

    Args:
        pdf_file: Anything ``pdfplumber.open`` accepts (here, the object
            returned by the Streamlit file uploader).

    Returns:
        The page texts joined by newlines, with leading and trailing
        whitespace stripped. Pages for which ``extract_text()`` returns
        nothing are skipped, so the result is ``""`` when no page has
        extractable text.
    """
    with pdfplumber.open(pdf_file) as pdf:
        # Extract while the document is still open.
        per_page = [page.extract_text() for page in pdf.pages]

    # Drop pages without extractable text (None or empty string).
    non_empty = [chunk for chunk in per_page if chunk]
    return "\n".join(non_empty).strip()
# Store the PDF text and CSV data.
# Both are rebuilt from the upload widgets each time the script runs.
pdf_text = ""
csv_data = None

# Streamlit app UI
st.title("RAG-Powered PDF & CSV Chatbot")

# CSV file upload
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
if csv_file:
    try:
        csv_data = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # A malformed upload should produce a message, not a crashed app;
        # csv_data stays None so the file is treated as not loaded.
        st.error(f"Could not read the CSV file: {e}")
    else:
        st.write("CSV file loaded successfully!")
        st.write(csv_data)

# PDF file upload
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
if pdf_file:
    try:
        pdf_text = extract_text_from_pdf(pdf_file)
    except Exception as e:
        # UI boundary: report unreadable/corrupt PDFs to the user instead of
        # aborting the script; pdf_text stays empty.
        st.error(f"Could not read the PDF file: {e}")
    else:
        if pdf_text:
            st.success("PDF loaded successfully!")
            st.text_area("Extracted Text from PDF", pdf_text, height=200)
        else:
            st.warning("No extractable text found in the PDF.")

# User input for chatbot
user_input = st.text_input("Ask a question related to the PDF or CSV:")

# Get response on button click
if st.button("Get Response"):
    question = user_input.strip()
    if not pdf_text and csv_data is None:
        st.warning("Please upload a PDF or CSV file first.")
    elif not question:
        # Do not send an empty question to the model.
        st.warning("Please enter a question first.")
    else:
        # Combine PDF text and CSV content for context in RAG; only the
        # sources that are actually present are included.
        context_parts = [pdf_text] if pdf_text else []
        if csv_data is not None:
            context_parts.append(csv_data.to_string())
        combined_context = "\n".join(context_parts)

        # Get RAG-generated response
        try:
            response = get_rag_embeddings(question, combined_context)
        except Exception as e:
            st.error(f"Error while processing the question: {e}")
        else:
            st.write("### Response:")
            st.write(response)