# pdfcsvdatarag / app.py
# Hugging Face Space file view: commit 663e818 ("Create app.py") by Shankarm08, verified, 4.23 kB.
import streamlit as st
import pandas as pd
import torch
import faiss
import numpy as np
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import pdfplumber
import pytesseract
from sklearn.metrics.pairwise import cosine_similarity
# Load the RAG tokenizer, retriever and model.
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so the heavyweight objects are created inside a cached loader
# and reused across reruns instead of being reloaded each time.
@st.cache_resource
def _load_rag_components():
    """Load and return the ``(tokenizer, retriever, model)`` triple for RAG.

    All three are loaded from the ``facebook/rag-token-nq`` checkpoint. The
    retriever is configured with the ``exact`` index on the dummy dataset.
    """
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
    return rag_tokenizer, rag_retriever, rag_model


# Module-level names kept so the rest of the file can keep using them.
tokenizer, retriever, model = _load_rag_components()
# Function to get embeddings for FAISS index
def get_faiss_index(data_chunks):
    """Build a flat L2 FAISS index over dense embeddings of ``data_chunks``.

    Args:
        data_chunks: Sequence of text strings to embed. Row ``i`` of the
            returned matrix corresponds to ``data_chunks[i]``.

    Returns:
        A ``(index, embeddings)`` tuple: the populated ``faiss.IndexFlatL2``
        and the float32 NumPy matrix that was added to it.

    Raises:
        ValueError: If ``data_chunks`` is empty.
    """
    chunks = list(data_chunks)
    if not chunks:
        raise ValueError("data_chunks must contain at least one text chunk")

    # Tokenize as a single padded batch: chunks tokenize to different lengths,
    # so per-chunk tensors cannot simply be concatenated.
    encoded = retriever.question_encoder_tokenizer(
        chunks, return_tensors="pt", padding=True, truncation=True
    )

    # Token ids are not embeddings; run them through the question encoder to
    # obtain one dense vector per chunk.
    # NOTE(review): only the question encoder is loaded in this file; a
    # dedicated passage/context encoder may suit document chunks better.
    with torch.no_grad():
        pooled = model.question_encoder(
            encoded["input_ids"], attention_mask=encoded["attention_mask"]
        )[0]

    # FAISS requires a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(pooled.cpu().numpy(), dtype=np.float32)

    # Build FAISS index (exact search, L2 distance)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
# Extract text from PDF (with OCR fallback)
def extract_text_from_pdf(pdf_file):
    """Return the text of every page in ``pdf_file``, one page per line group.

    Pages with an embedded text layer are read directly. Pages without one
    are rendered to an image and passed through Tesseract OCR. Progress and
    failures for individual pages are reported through Streamlit messages.

    Args:
        pdf_file: Path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        The extracted text, with a newline appended after each page's text.
        Pages that yield nothing contribute nothing.
    """
    # pdfplumber renders at 72 DPI by default, which is too coarse for
    # reliable OCR; render the fallback image at a higher resolution.
    ocr_resolution = 300

    page_texts = []
    with pdfplumber.open(pdf_file) as pdf:
        for page_num, page in enumerate(pdf.pages, 1):
            page_text = page.extract_text()
            if page_text:
                page_texts.append(page_text)
                continue

            st.warning(f"No extractable text found on page {page_num}. Using OCR...")
            try:
                page_image = page.to_image(resolution=ocr_resolution).original
                ocr_text = pytesseract.image_to_string(page_image)
            except pytesseract.TesseractNotFoundError:
                # A missing Tesseract binary should not take the whole app
                # down; report it and move on to the next page.
                st.error(
                    f"OCR is unavailable (Tesseract executable not found); skipping page {page_num}."
                )
                continue

            if ocr_text.strip():
                page_texts.append(ocr_text)
            else:
                st.error(f"Even OCR couldn't extract text from page {page_num}.")

    # Join once at the end instead of growing a string inside the loop.
    return "".join(f"{page_text}\n" for page_text in page_texts)
# Function to process input for RAG model
def generate_rag_response(user_input, data_chunks):
    """Generate an answer to ``user_input`` with the RAG model.

    The question is encoded, the retriever fetches supporting documents for
    that encoding, each document is scored against the question, and the
    generator decodes an answer from the retrieved contexts.

    Args:
        user_input: The question text.
        data_chunks: Text chunks extracted from the uploaded PDF.
            NOTE(review): these are not consulted yet; retrieval runs against
            whatever index the module-level ``retriever`` was loaded with, so
            answers are not grounded in the PDF. Confirm the intended design.

    Returns:
        The first decoded answer string.
    """
    n_docs = 5

    inputs = tokenizer([user_input], return_tensors="pt")
    input_ids = inputs["input_ids"]

    with torch.no_grad():
        # 1. Encode the question into the dense retrieval space.
        question_hidden_states = model.question_encoder(input_ids)[0]

        # 2. Retrieve documents. The retriever expects the question ids and
        #    the question embeddings as NumPy arrays, not an ``input_ids`` kwarg.
        retrieved_docs = retriever(
            input_ids.numpy(),
            question_hidden_states.detach().numpy(),
            n_docs=n_docs,
            return_tensors="pt",
        )

        # 3. Score each retrieved document against the question; generate()
        #    requires these scores when contexts are supplied explicitly.
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            retrieved_docs["retrieved_doc_embeds"].float().transpose(1, 2),
        ).squeeze(1)

        # 4. Generate from the retrieved contexts.
        outputs = model.generate(
            context_input_ids=retrieved_docs["context_input_ids"],
            context_attention_mask=retrieved_docs["context_attention_mask"],
            doc_scores=doc_scores,
            n_docs=n_docs,
        )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Streamlit app: upload a CSV and/or a PDF, then ask a question about them.
st.title("CSV and PDF Chatbot with RAG")

# CSV file upload
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
csv_data = None
if csv_file:
    try:
        csv_data = pd.read_csv(csv_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
        # A malformed upload should produce a message, not a crashed script.
        st.error(f"Could not read the CSV file: {e}")
    else:
        st.success("CSV loaded successfully!")
        st.write("### CSV Data:")
        st.write(csv_data)

# PDF file upload
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
pdf_text = ""
data_chunks = []
if pdf_file:
    pdf_text = extract_text_from_pdf(pdf_file)
    if not pdf_text.strip():
        st.error("The extracted PDF text is empty. Please upload a PDF with extractable text.")
    else:
        st.success("PDF loaded successfully!")
        st.write("### Extracted Text:")
        st.write(pdf_text)
        # Split the extracted text into line-level chunks, dropping blank
        # lines so empty strings are never treated as content.
        data_chunks = [line for line in pdf_text.split('\n') if line.strip()]
        st.write("### Extracted Chunks:")
        for chunk in data_chunks[:5]:  # Display first 5 chunks
            st.write(chunk)

# User input for chatbot
user_input = st.text_input("Ask a question about the CSV or PDF:")
if st.button("Get Response"):
    query = user_input.strip()
    if csv_data is None and not data_chunks:
        # Either source alone is enough, so ask for one or the other.
        st.warning("Please upload a CSV or PDF file first.")
    elif not query:
        st.warning("Please enter a question.")
    else:
        # Top-level UI boundary: surface any failure to the user.
        try:
            if csv_data is not None:
                # Literal, case-insensitive substring search over every cell.
                # regex=False keeps characters such as "?" or "(" in the
                # question from being interpreted as a regular expression.
                row_matches = csv_data.astype(str).apply(
                    lambda column: column.str.contains(query, case=False, regex=False)
                ).any(axis=1)
                csv_response = csv_data[row_matches]
                if not csv_response.empty:
                    st.write("### CSV Response:")
                    st.write(csv_response)
                else:
                    st.write("No relevant data found in the CSV.")
            if data_chunks:
                # Generate response using RAG for PDF content
                response = generate_rag_response(query, data_chunks)
                st.write("### PDF Response:")
                st.write(response)
        except Exception as e:
            st.error(f"Error while processing user input: {e}")