from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, Any
import logging
from enum import Enum

import gradio as gr
import PyPDF2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel

# Initialize models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embed_model = HuggingFaceBgeEmbeddings(
    model_name="all-MiniLM-L6-v2",  # alternative: "dunzhang/stella_en_1.5B_v5"
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)
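# Note: HuggingFaceBgeEmbeddings is tailored to BGE models and prepends a
# retrieval instruction to queries; for a plain sentence-transformers model
# like all-MiniLM-L6-v2 (384-dimensional), HuggingFaceEmbeddings would be the
# more natural wrapper. normalize_embeddings=True makes the vectors
# unit-length, so similarity scores behave like cosine similarity.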
model_name = "meta-llama/Llama-3.2-3B-Instruct"  # alternatives: "google/gemma-2-2b-it", "prithivMLmods/Llama-3.2-3B-GGUF"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    token=True,  # meta-llama checkpoints are gated; uses the locally saved HF token
)
# Llama has no dedicated pad token; reuse EOS to avoid generation warnings
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Wrap the model in a text-generation pipeline so LangChain can drive it;
# the generation settings here are illustrative defaults, not tuned values
text_generation = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    max_new_tokens=512,
)
llm = HuggingFacePipeline(pipeline=text_generation)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DocumentFormat(Enum):
    PDF = ".pdf"
    # Can be extended for other document types


@dataclass
class RAGConfig:
    """Configuration for RAG system parameters"""
    chunk_size: int = 500
    chunk_overlap: int = 100
    retriever_k: int = 3
    persist_directory: str = "./chroma_db"
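# A hypothetical override for larger documents (values are illustrative,
# not tuned): RAGConfig(chunk_size=1000, chunk_overlap=200, retriever_k=5)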
class AdvancedRAGSystem:
    """Advanced RAG system with improved error handling and type safety"""

    DEFAULT_TEMPLATE = """<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Context:
{context}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    def __init__(
        self,
        embed_model: Embeddings,
        llm: BaseLanguageModel,
        config: Optional[RAGConfig] = None
    ):
        """Initialize the RAG system with required models and optional configuration"""
        self.embed_model = embed_model
        self.llm = llm
        self.config = config or RAGConfig()
        self.vector_store: Optional[Chroma] = None
        self.last_context: Optional[str] = None
        self.prompt = PromptTemplate(
            template=self.DEFAULT_TEMPLATE,
            input_variables=["context", "question"]
        )

    def _validate_file(self, file_path: Path) -> bool:
        """Validate that the file exists and is of a supported format"""
        return file_path.suffix.lower() == DocumentFormat.PDF.value and file_path.exists()

    def _extract_text_from_pdf(self, pdf_path: Path) -> str:
        """Extract text from a PDF file with proper error handling"""
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                return "\n".join(
                    page.extract_text()
                    for page in pdf_reader.pages
                )
        except Exception as e:
            logger.error(f"Error processing PDF {pdf_path}: {str(e)}")
            raise ValueError(f"Failed to process PDF {pdf_path}: {str(e)}") from e

    def _create_document_chunks(self, texts: List[str]) -> List[Any]:
        """Split documents into chunks using the configured parameters"""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            length_function=len,
            add_start_index=True,
        )
        return text_splitter.create_documents(texts)
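    # RecursiveCharacterTextSplitter tries progressively smaller separators
    # (paragraphs, then lines, then words) until chunks fit chunk_size;
    # chunk_overlap repeats the tail of each chunk at the head of the next so
    # retrieval doesn't lose sentences that straddle a boundary, and
    # add_start_index records each chunk's character offset in its metadata.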
    def process_pdfs(self, pdf_files: List[Any]) -> str:
        """Process and index PDF documents with improved error handling"""
        try:
            # Convert Gradio file objects to Path objects and validate
            pdf_paths = [Path(pdf.name) for pdf in pdf_files]
            invalid_files = [f for f in pdf_paths if not self._validate_file(f)]
            if invalid_files:
                raise ValueError(f"Invalid or missing files: {invalid_files}")

            # Extract text from the valid PDFs
            documents = [
                self._extract_text_from_pdf(pdf_path)
                for pdf_path in pdf_paths
            ]

            # Create document chunks
            doc_chunks = self._create_document_chunks(documents)

            # Initialize or replace the vector store
            self.vector_store = Chroma.from_documents(
                documents=doc_chunks,
                embedding=self.embed_model,
                persist_directory=self.config.persist_directory
            )

            status = f"Successfully processed {len(doc_chunks)} chunks from {len(pdf_files)} PDF files"
            logger.info(status)
            return status
        except Exception as e:
            error_msg = f"Error during PDF processing: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def get_retriever(self) -> BaseRetriever:
        """Get the document retriever with the current configuration"""
        if not self.vector_store:
            raise RuntimeError("Vector store not initialized. Please process documents first.")
        return self.vector_store.as_retriever(search_kwargs={"k": self.config.retriever_k})
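    # as_retriever defaults to similarity search: it returns the k chunks
    # whose embeddings are closest to the embedded question.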
    def _format_context(self, documents: List[Any]) -> str:
        """Format retrieved documents into a single context string"""
        return "\n\n".join(doc.page_content for doc in documents)

    def query(self, question: str) -> Dict[str, Any]:
        """Query the RAG system with improved error handling and response formatting"""
        try:
            if not self.vector_store:
                raise RuntimeError("Please process PDF documents first before querying")

            # Retrieve relevant documents
            retriever = self.get_retriever()
            retrieved_docs = retriever.get_relevant_documents(question)
            context = self._format_context(retrieved_docs)
            self.last_context = context

            # Generate a response using the LLM
            response = self.llm.invoke(
                self.prompt.format(
                    context=context,
                    question=question
                )
            )

            return {
                # The pipeline echoes the prompt; keep only the text after
                # the final assistant header
                "answer": response.split("<|end_header_id|>")[-1],
                "context": context,
                "source_documents": len(retrieved_docs)
            }
        except Exception as e:
            error_msg = f"Error during query processing: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e
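# A minimal sketch of driving the class without the UI (assumes the
# embed_model and llm built above, and "paper.pdf" as a stand-in filename):
#
#     rag = AdvancedRAGSystem(embed_model, llm, RAGConfig(retriever_k=4))
#     rag.process_pdfs([Path("paper.pdf")])
#     print(rag.query("What is the main contribution?")["answer"])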
def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
    """Create an improved Gradio interface for the RAG system"""

    def process_files(files: List[Any], chunk_size: int, overlap: int) -> str:
        """Process uploaded files with the updated configuration"""
        if not files:
            return "Please upload PDF files"

        # Update the configuration with the new parameters
        rag_system.config.chunk_size = chunk_size
        rag_system.config.chunk_overlap = overlap

        try:
            return rag_system.process_pdfs(files)
        except Exception as e:
            return f"Error: {str(e)}"

    def query_and_update_history(question: str) -> tuple[str, str]:
        """Query the system and update the history, with error handling"""
        try:
            result = rag_system.query(question)
            return (
                result["answer"],
                f"Last context used ({result['source_documents']} documents):\n\n{result['context']}"
            )
        except Exception as e:
            return str(e), "Error occurred while retrieving context"

    with gr.Blocks(title="Advanced RAG System") as demo:
        gr.Markdown("# Advanced RAG System with PDF Processing")

        with gr.Tab("Upload & Process PDFs"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.File(
                        file_count="multiple",
                        label="Upload PDF Documents",
                        file_types=[".pdf"]
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=10000,
                        value=500,
                        step=100,
                        label="Chunk Size"
                    )
                    overlap = gr.Slider(
                        minimum=10,
                        maximum=5000,
                        value=100,
                        step=10,
                        label="Chunk Overlap"
                    )
                    process_button = gr.Button("Process PDFs", variant="primary")
                    process_output = gr.Textbox(label="Processing Status")

        with gr.Tab("Query System"):
            with gr.Row():
                with gr.Column(scale=2):
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="Enter your question here...",
                        lines=3
                    )
                    query_button = gr.Button("Get Answer", variant="primary")
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=10
                    )
                with gr.Column(scale=1):
                    history_output = gr.Textbox(
                        label="Retrieved Context",
                        lines=15
                    )

        # Set up event handlers
        process_button.click(
            fn=process_files,
            inputs=[file_input, chunk_size, overlap],
            outputs=[process_output]
        )
        query_button.click(
            fn=query_and_update_history,
            inputs=[question_input],
            outputs=[answer_output, history_output]
        )

    return demo
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
# demo = gr.ChatInterface( | |
# respond, | |
# additional_inputs=[ | |
# gr.Textbox(value="You are a friendly Chatbot.", label="System message"), | |
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"), | |
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"), | |
# gr.Slider( | |
# minimum=0.1, | |
# maximum=1.0, | |
# value=0.95, | |
# step=0.05, | |
# label="Top-p (nucleus sampling)", | |
# ), | |
# ], | |
# ) | |
rag_system = AdvancedRAGSystem(embed_model, llm)
demo = create_gradio_interface(rag_system)

if __name__ == "__main__":
    demo.launch()