import streamlit as st
import logging
import os
import tempfile
import shutil
import pdfplumber
import ollama
import time
import httpx
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever
from typing import List, Tuple, Dict, Any, Optional
# Streamlit page configuration
st.set_page_config(
page_title="Ollama PDF RAG Streamlit UI",
page_icon="🎈",
layout="wide",
initial_sidebar_state="collapsed",
)
# Logging configuration
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
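# The Ollama server may not be reachable immediately (for example while it is
# still starting), so listing models is retried a few times before failing.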
def ollama_list_with_retry(retries=3, delay=5):
"""Attempt to list models from Ollama with retry logic."""
for attempt in range(retries):
try:
response = ollama.list()
logger.info("Successfully retrieved model list from Ollama")
return response
except httpx.ConnectError as e:
logger.error(f"Connection error: {e}. Attempt {attempt + 1} of {retries}")
if attempt < retries - 1:
time.sleep(delay)
else:
logger.error("All retry attempts failed. Cannot connect to Ollama service.")
raise
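# Cached with st.cache_resource so the parsed model names are reused across Streamlit reruns.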
@st.cache_resource(show_spinner=True)
def extract_model_names(models_info: Dict[str, List[Dict[str, Any]]]) -> Tuple[str, ...]:
"""Extract model names from the provided models information."""
logger.info("Extracting model names from models_info")
model_names = tuple(model["name"] for model in models_info["models"])
logger.info(f"Extracted model names: {model_names}")
return model_names
def create_vector_db(file_upload) -> Chroma:
"""Create a vector database from an uploaded PDF file."""
logger.info(f"Creating vector DB from file upload: {file_upload.name}")
temp_dir = tempfile.mkdtemp()
path = os.path.join(temp_dir, file_upload.name)
with open(path, "wb") as f:
f.write(file_upload.getvalue())
logger.info(f"File saved to temporary path: {path}")
loader = UnstructuredPDFLoader(path)
data = loader.load()
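    # Split the extracted text into large, slightly overlapping chunks before embedding.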
text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
chunks = text_splitter.split_documents(data)
logger.info("Document split into chunks")
embeddings = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
vector_db = Chroma.from_documents(
documents=chunks, embedding=embeddings, collection_name="myRAG"
)
logger.info("Vector DB created")
shutil.rmtree(temp_dir)
logger.info(f"Temporary directory {temp_dir} removed")
return vector_db
def process_question(question: str, vector_db: Chroma, selected_model: str) -> str:
"""Process a user question using the vector database and selected language model."""
logger.info(f"Processing question: {question} using model: {selected_model}")
llm = ChatOllama(model=selected_model, temperature=0)
QUERY_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an AI language model assistant. Your task is to generate 3
different versions of the given user question to retrieve relevant documents from
a vector database. By generating multiple perspectives on the user question, your
goal is to help the user overcome some of the limitations of the distance-based
similarity search. Provide these alternative questions separated by newlines.
Original question: {question}""",
)
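    # MultiQueryRetriever asks the LLM for several rephrasings of the question and
    # merges the documents retrieved for each one.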
retriever = MultiQueryRetriever.from_llm(
vector_db.as_retriever(), llm, prompt=QUERY_PROMPT
)
template = """Answer the question based ONLY on the following context:
{context}
Question: {question}
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Only provide the answer from the {context}, nothing else.
Add snippets of the context you used to answer the question.
"""
prompt = ChatPromptTemplate.from_template(template)
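    # LCEL pipeline: retrieve context for the question, fill the prompt, call the
    # model, and parse the output to a plain string.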
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
response = chain.invoke(question)
logger.info("Question processed and response generated")
return response
@st.cache_data
def extract_all_pages_as_images(file_upload) -> List[Any]:
"""Extract all pages from a PDF file as images."""
logger.info(f"Extracting all pages as images from file: {file_upload.name}")
pdf_pages = []
with pdfplumber.open(file_upload) as pdf:
pdf_pages = [page.to_image().original for page in pdf.pages]
logger.info("PDF pages extracted as images")
return pdf_pages
def delete_vector_db(vector_db: Optional[Chroma]) -> None:
"""Delete the vector database and clear related session state."""
logger.info("Deleting vector DB")
if vector_db is not None:
vector_db.delete_collection()
st.session_state.pop("pdf_pages", None)
st.session_state.pop("file_upload", None)
st.session_state.pop("vector_db", None)
st.success("Collection and temporary files deleted successfully.")
logger.info("Vector DB and related session state cleared")
st.rerun()
else:
st.error("No vector database found to delete.")
logger.warning("Attempted to delete vector DB, but none was found")
def main() -> None:
"""Main function to run the Streamlit application."""
st.subheader("🧠 Ollama PDF RAG playground", divider="gray", anchor=False)
try:
models_info = ollama_list_with_retry()
available_models = extract_model_names(models_info)
except httpx.ConnectError:
st.error("Could not connect to the Ollama service. Please check your setup and try again.")
return
col1, col2 = st.columns([1.5, 2])
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "vector_db" not in st.session_state:
st.session_state["vector_db"]