import streamlit as st
import logging
import os
import tempfile
import shutil
import pdfplumber
import ollama
import time
import httpx
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever
from typing import List, Tuple, Dict, Any, Optional
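
# NOTE: the langchain_community wrappers used here (OllamaEmbeddings,
# ChatOllama) are the pre-split imports; newer LangChain releases ship the
# same classes in the dedicated langchain-ollama package. The community
# versions still work, but may emit deprecation warnings on recent installs.
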
# Streamlit page configuration
st.set_page_config(
    page_title="Ollama PDF RAG Streamlit UI",
    page_icon="🎈",
    layout="wide",
    initial_sidebar_state="collapsed",
)
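
# st.set_page_config must be the first Streamlit command executed in the
# script; calling it after any other st.* call raises a StreamlitAPIException.
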
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)
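
# Streamlit re-runs this whole script on every user interaction, but
# logging.basicConfig is a no-op once the root logger already has handlers,
# so the configuration above takes effect only on the first run of the
# process.
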
def ollama_list_with_retry(retries=3, delay=5):
    """Attempt to list models from Ollama with retry logic."""
    for attempt in range(retries):
        try:
            response = ollama.list()
            logger.info("Successfully retrieved model list from Ollama")
            return response
        except httpx.ConnectError as e:
            logger.error(f"Connection error: {e}. Attempt {attempt + 1} of {retries}")
            if attempt < retries - 1:
                time.sleep(delay)
            else:
                logger.error("All retry attempts failed. Cannot connect to Ollama service.")
                raise
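
# Usage sketch: with the defaults this wrapper blocks for at most ~10s (two
# 5s sleeps) before re-raising, so callers should still guard against
# failure, e.g.:
#
#     try:
#         models_info = ollama_list_with_retry(retries=5, delay=2)
#     except httpx.ConnectError:
#         st.stop()
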
def extract_model_names(models_info: Dict[str, List[Dict[str, Any]]]) -> Tuple[str, ...]:
    """Extract model names from the provided models information."""
    logger.info("Extracting model names from models_info")
    model_names = tuple(model["name"] for model in models_info["models"])
    logger.info(f"Extracted model names: {model_names}")
    return model_names
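
# NOTE: this assumes the dict-style response of older ollama clients,
# {"models": [{"name": ...}, ...]}. Recent releases of the ollama package
# return a typed ListResponse whose entries expose the name via a .model
# attribute, so on an upgraded client this would instead look like:
#
#     tuple(m.model for m in models_info.models)
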
def create_vector_db(file_upload) -> Chroma:
    """Create a vector database from an uploaded PDF file."""
    logger.info(f"Creating vector DB from file upload: {file_upload.name}")
    temp_dir = tempfile.mkdtemp()

    path = os.path.join(temp_dir, file_upload.name)
    with open(path, "wb") as f:
        f.write(file_upload.getvalue())
        logger.info(f"File saved to temporary path: {path}")

    loader = UnstructuredPDFLoader(path)
    data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
    chunks = text_splitter.split_documents(data)
    logger.info("Document split into chunks")

    embeddings = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
    vector_db = Chroma.from_documents(
        documents=chunks, embedding=embeddings, collection_name="myRAG"
    )
    logger.info("Vector DB created")

    shutil.rmtree(temp_dir)
    logger.info(f"Temporary directory {temp_dir} removed")
    return vector_db
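
# Design note: Chroma.from_documents without a persist_directory builds an
# in-memory collection, so the index lives only as long as the Streamlit
# session. The 7500-character chunks lean on nomic-embed-text's long context
# window; shrinking chunk_size gives finer-grained retrieval at the cost of
# more embedding calls.
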
def process_question(question: str, vector_db: Chroma, selected_model: str) -> str:
    """Process a user question using the vector database and selected language model."""
    logger.info(f"Processing question: {question} using model: {selected_model}")
    llm = ChatOllama(model=selected_model, temperature=0)

    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate 3
        different versions of the given user question to retrieve relevant documents from
        a vector database. By generating multiple perspectives on the user question, your
        goal is to help the user overcome some of the limitations of the distance-based
        similarity search. Provide these alternative questions separated by newlines.
        Original question: {question}""",
    )

    retriever = MultiQueryRetriever.from_llm(
        vector_db.as_retriever(), llm, prompt=QUERY_PROMPT
    )

    template = """Answer the question based ONLY on the following context:
{context}
Question: {question}
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Only provide the answer from the {context}, nothing else.
Add snippets of the context you used to answer the question.
"""

    prompt = ChatPromptTemplate.from_template(template)

    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    response = chain.invoke(question)
    logger.info("Question processed and response generated")
    return response
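
# How the chain fits together: MultiQueryRetriever asks the LLM to rewrite
# the question three ways (QUERY_PROMPT), runs a similarity search for each
# variant, and merges the deduplicated hits. The LCEL pipeline then maps the
# incoming question to {"context": <retrieved docs>, "question": <question>},
# fills the answer prompt, calls the chat model, and StrOutputParser unwraps
# the resulting AIMessage into a plain string.
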
def extract_all_pages_as_images(file_upload) -> List[Any]:
    """Extract all pages from a PDF file as images."""
    logger.info(f"Extracting all pages as images from file: {file_upload.name}")
    pdf_pages = []
    with pdfplumber.open(file_upload) as pdf:
        pdf_pages = [page.to_image().original for page in pdf.pages]
    logger.info("PDF pages extracted as images")
    return pdf_pages
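
# NOTE: page.to_image() rasterizes at pdfplumber's default 72 dpi, and
# .original returns the underlying PIL image without debug annotations; pass
# to_image(resolution=150) or higher if the rendered pages look blurry in the
# viewer.
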
def delete_vector_db(vector_db: Optional[Chroma]) -> None:
    """Delete the vector database and clear related session state."""
    logger.info("Deleting vector DB")
    if vector_db is not None:
        vector_db.delete_collection()
        st.session_state.pop("pdf_pages", None)
        st.session_state.pop("file_upload", None)
        st.session_state.pop("vector_db", None)
        st.success("Collection and temporary files deleted successfully.")
        logger.info("Vector DB and related session state cleared")
        st.rerun()
    else:
        st.error("No vector database found to delete.")
        logger.warning("Attempted to delete vector DB, but none was found")
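
# st.rerun() raises Streamlit's internal rerun exception, immediately halting
# this run and re-executing the script from the top so the UI reflects the
# cleared session state; nothing after the call executes in the current run.
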
def main() -> None:
    """Main function to run the Streamlit application."""
    st.subheader("🧠 Ollama PDF RAG playground", divider="gray", anchor=False)

    try:
        models_info = ollama_list_with_retry()
        available_models = extract_model_names(models_info)
    except httpx.ConnectError:
        st.error("Could not connect to the Ollama service. Please check your setup and try again.")
        return

    col1, col2 = st.columns([1.5, 2])

    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    if "vector_db" not in st.session_state:
        st.session_state["vector_db"] = None