import os
import time
import io
import base64
import re
import functools
import numpy as np
import fitz # PyMuPDF
import tempfile
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from ultralytics import YOLO
import streamlit as st
import streamlit.components.v1  # ensures st.components.v1 resolves on older Streamlit versions
from streamlit_chat import message
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import SpacyTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit import runtime
# Initialize models and environment
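# Fetch the spaCy model required by SpacyTextSplitter at startup.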
os.system("python -m spacy download en_core_web_sm")
model = YOLO("best.pt")
openai_api_key = os.environ.get("openai_api_key")
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
# Utility functions
@st.cache_data(show_spinner=False, ttl=3600)
def clean_text(text):
return re.sub(r'\s+', ' ', text).strip()
def remove_references(text):
    # Truncate at the first reference-section heading; the search is case-insensitive,
    # so one spelling per heading is enough.
    reference_patterns = [
        r'\bReferences\b', r'\bBibliography\b',
        r'\bCitations\b', r'\bWorks Cited\b', r'\bReference\b'
    ]
lines = text.split('\n')
for i, line in enumerate(lines):
if any(re.search(pattern, line, re.IGNORECASE) for pattern in reference_patterns):
return '\n'.join(lines[:i])
return text
def handle_errors(func):
    @functools.wraps(func)  # preserve the function's identity so st.cache_data can hash it
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            st.session_state.chat_history.append({
                "bot": f"❌ An error occurred: {str(e)}"
            })
            st.rerun()
    return wrapper
def show_progress(message):
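    # Purely cosmetic progress animation (~2 s); the real work runs after it completes.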
progress_bar = st.progress(0)
status_text = st.empty()
for i in range(100):
time.sleep(0.02)
progress_bar.progress(i + 1)
status_text.text(f"{message}... {i+1}%")
progress_bar.empty()
status_text.empty()
def scroll_to_bottom():
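    # Inject a small JS snippet that scrolls the main panel to the newest message.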
ctx = get_script_run_ctx()
if ctx and runtime.exists():
js = """
<script>
function scrollToBottom() {
window.parent.document.querySelector('section.main').scrollTo(0, window.parent.document.querySelector('section.main').scrollHeight);
}
setTimeout(scrollToBottom, 100);
</script>
"""
st.components.v1.html(js, height=0)
# ----------------------------
# Core Processing Functions
# ----------------------------
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def summarize_pdf_with_tooltips(_pdf_file_path, num_clusters=10):
"""
Generates a summary with in-text citations that display the full excerpt as a tooltip on hover.
Each citation is embedded as an HTML span element with the tooltip text.
"""
embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
prompt = ChatPromptTemplate.from_template(
"""Generate a comprehensive summary that includes the following:
1. Key findings and conclusions
2. Main methodologies used
3. Important data points
4. Limitations mentioned
For any information directly derived from the context excerpts provided below, insert an in-text citation as an HTML tooltip.
For each citation, use the following HTML format:
<span class="tooltip" data-tooltip="{{full_text}}">[{{n}}]</span>
Where:
- {{n}} is the citation number.
- {{full_text}} is the complete excerpt text for that citation.
Do not provide a separate reference list. Instead, embed the full citation text directly in the tooltip.
Context Excerpts:
{contexts}"""
)
loader = PyMuPDFLoader(_pdf_file_path)
docs = loader.load()
full_text = "\n".join(doc.page_content for doc in docs)
cleaned_full_text = clean_text(remove_references(full_text))
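    # Sentence-aware (spaCy) splitting keeps each ~500-character excerpt readable in a tooltip.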
text_splitter = SpacyTextSplitter(chunk_size=500)
split_contents = text_splitter.split_text(cleaned_full_text)
    embeddings = np.array(embeddings_model.embed_documents(split_contents))
    num_clusters = min(num_clusters, len(split_contents))  # never request more clusters than chunks
    kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(embeddings)
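    # For each cluster, cite the chunk closest to the centroid as its representative excerpt.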
citation_indices = []
for center in kmeans.cluster_centers_:
distances = np.linalg.norm(embeddings - center, axis=1)
idx = int(np.argmin(distances))
citation_indices.append(idx)
# Build the context excerpts string.
citation_contexts = []
for i, idx in enumerate(citation_indices):
# Replace double quotes to avoid breaking HTML attribute quotes.
excerpt = split_contents[idx].replace('"', "'")
citation_contexts.append(f"[{i+1}]: {excerpt}")
combined_contexts = "\n\n".join(citation_contexts)
chain = prompt | llm | StrOutputParser()
result = chain.invoke({"contexts": combined_contexts})
return result
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def qa_pdf(_pdf_file_path, query, top_k=5):
embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
llm = ChatOpenAI(model="gpt-4", api_key=openai_api_key, temperature=0.3)
prompt = ChatPromptTemplate.from_template(
"""Answer this question: {question}
Using only this context: {context}
Format your answer with:
- Clear section headings
- Bullet points for lists
- **Bold** key terms
- Citations from the text"""
)
loader = PyMuPDFLoader(_pdf_file_path)
docs = loader.load()
full_text = "\n".join(doc.page_content for doc in docs)
cleaned_full_text = clean_text(remove_references(full_text))
text_splitter = SpacyTextSplitter(chunk_size=500)
split_contents = text_splitter.split_text(cleaned_full_text)
query_embedding = embeddings_model.embed_query(query)
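    # Rank every chunk by cosine similarity to the query embedding; the top-k become the context.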
similarities = cosine_similarity([query_embedding],
embeddings_model.embed_documents(split_contents))[0]
    top_indices = np.argsort(similarities)[-top_k:][::-1]  # most similar chunks first
chain = prompt | llm | StrOutputParser()
return chain.invoke({
"question": query,
"context": ' '.join([split_contents[i] for i in top_indices])
})
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def process_pdf(_pdf_file_path):
doc = fitz.open(_pdf_file_path)
all_figures, all_tables = [], []
scale_factor = 300 / 50 # High-res to low-res ratio
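    # Two-pass detection: find regions on a cheap 50-dpi render, then crop from a 300-dpi render.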
for page in doc:
low_res = page.get_pixmap(dpi=50)
low_res_img = np.frombuffer(low_res.samples, dtype=np.uint8).reshape(low_res.height, low_res.width, 3)
results = model.predict(low_res_img)
boxes = [
(int(box.xyxy[0][0]), int(box.xyxy[0][1]),
int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
for result in results for box in result.boxes
if box.conf[0] > 0.8 and int(box.cls[0]) in {3, 4}
]
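        # Class ids 3 (table) and 4 (figure) are assumptions tied to the custom-trained best.pt weights.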
if boxes:
high_res = page.get_pixmap(dpi=300)
high_res_img = np.frombuffer(high_res.samples, dtype=np.uint8).reshape(high_res.height, high_res.width, 3)
for (x1, y1, x2, y2, cls) in boxes:
cropped = high_res_img[int(y1*scale_factor):int(y2*scale_factor),
int(x1*scale_factor):int(x2*scale_factor)]
if cls == 4:
all_figures.append(cropped)
else:
all_tables.append(cropped)
    doc.close()
    return all_figures, all_tables
def image_to_base64(img):
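    # Downscale and JPEG-encode so the inline data-URI images keep chat payloads small.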
buffered = io.BytesIO()
img = Image.fromarray(img).convert("RGB")
img.thumbnail((800, 800)) # Optimize image size
img.save(buffered, format="JPEG", quality=85)
return base64.b64encode(buffered.getvalue()).decode()
# ----------------------------
# Streamlit UI Setup
# ----------------------------
st.set_page_config(
page_title="PDF Assistant",
page_icon="πŸ“„",
layout="wide",
initial_sidebar_state="expanded"
)
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
if 'current_file' not in st.session_state:
st.session_state.current_file = None
st.title("πŸ“„ Smart PDF Analyzer")
st.markdown("""
<div style="border-left: 4px solid #4CAF50; padding-left: 1rem; margin: 1rem 0;">
  <p style="color: #666; font-size: 0.95rem;">✨ Upload a PDF to:</p>
  <ul style="color: #666; font-size: 0.95rem;">
    <li>Generate structured summaries</li>
    <li>Extract visual content</li>
    <li>Ask contextual questions</li>
  </ul>
</div>
""", unsafe_allow_html=True)
uploaded_file = st.file_uploader(
"Choose PDF file",
type="pdf",
help="Max file size: 50MB",
on_change=lambda: setattr(st.session_state, 'chat_history', [])
)
if uploaded_file and uploaded_file.size > MAX_FILE_SIZE:
st.error("File size exceeds 50MB limit")
st.stop()
if uploaded_file:
    # Write the upload to a temp file once per upload; a new path on every rerun would defeat the st.cache_data keys.
    if st.session_state.get('file_path') is None or st.session_state.current_file != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(uploaded_file.getbuffer())
        st.session_state.file_path = tmp.name
    file_path = st.session_state.file_path
# Let the user choose whether to include in-text citations in the summary
include_citations = st.checkbox("Include in-text citations in summary", value=True)
chat_container = st.container()
with chat_container:
for idx, chat in enumerate(st.session_state.chat_history):
col1, col2 = st.columns([1, 4])
if chat.get("user"):
with col2:
message(chat["user"], is_user=True, key=f"user_{idx}")
if chat.get("bot"):
with col1:
message(chat["bot"], key=f"bot_{idx}", allow_html=True)
scroll_to_bottom()
with st.container():
col1, col2, col3 = st.columns([3, 2, 2])
with col1:
user_input = st.chat_input("Ask about the document...")
with col2:
if st.button("πŸ“ Generate Summary", use_container_width=True):
with st.spinner("Analyzing document structure..."):
show_progress("Generating summary")
                summary = summarize_pdf_with_tooltips(file_path)
                if not include_citations:
                    # Honor the checkbox by stripping the tooltip citation spans.
                    summary = re.sub(r'<span class="tooltip".*?</span>', '', summary, flags=re.DOTALL)
st.session_state.chat_history.append({
"user": "Summary request",
"bot": f"## Document Summary\n{summary}"
})
st.rerun()
with col3:
if st.button("πŸ–ΌοΈ Extract Visuals", use_container_width=True):
with st.spinner("Identifying figures and tables..."):
show_progress("Extracting visuals")
figures, tables = process_pdf(file_path)
if figures:
st.session_state.chat_history.append({
"bot": f"Found {len(figures)} figures:"
})
for fig in figures:
st.session_state.chat_history.append({
"bot": f'<img src="data:image/jpeg;base64,{image_to_base64(fig)}" style="max-width: 100%;">'
})
if tables:
st.session_state.chat_history.append({
"bot": f"Found {len(tables)} tables:"
})
for tab in tables:
st.session_state.chat_history.append({
"bot": f'<img src="data:image/jpeg;base64,{image_to_base64(tab)}" style="max-width: 100%;">'
})
st.rerun()
if user_input:
st.session_state.chat_history.append({"user": user_input})
with st.spinner("Analyzing query..."):
show_progress("Generating answer")
answer = qa_pdf(file_path, user_input)
st.session_state.chat_history[-1]["bot"] = f"## Answer\n{answer}"
st.rerun()
st.markdown("""
<style>
.stChatMessage {
padding: 1.25rem;
margin: 1rem 0;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
transition: transform 0.2s ease;
}
.stChatMessage:hover {
transform: translateY(-2px);
}
.stButton>button {
background: linear-gradient(45deg, #4CAF50, #45a049);
color: white;
border: none;
border-radius: 8px;
padding: 12px 24px;
font-size: 16px;
transition: all 0.3s ease;
}
.stButton>button:hover {
box-shadow: 0 4px 12px rgba(76,175,80,0.3);
transform: translateY(-1px);
}
[data-testid="stFileUploader"] {
border: 2px dashed #4CAF50;
border-radius: 12px;
padding: 2rem;
}
.tooltip {
position: relative;
cursor: pointer;
border-bottom: 1px dotted #555;
}
/* Tooltip text */
.tooltip:hover::after {
content: attr(data-tooltip);
position: absolute;
left: 0;
top: 1.5em;
background: #333;
color: #fff;
padding: 5px 10px;
border-radius: 5px;
white-space: pre-wrap;
z-index: 100;
width: 300px; /* Adjust width as needed */
}
</style>
""", unsafe_allow_html=True)