import os
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel
import torch
import chromadb
from typing import List, Dict
import re
import numpy as np
from pathlib import Path


class LegalDocumentProcessor:
    def __init__(self):
        print("Initializing Legal Document Processor...")
        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        self.model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        self.max_chunk_size = 500  # Reduced chunk size
        self.max_context_length = 4000  # Maximum context length for response

        # Initialize ChromaDB
        self.pdf_dir = "/home/user/app"
        db_dir = os.path.join(self.pdf_dir, "chroma_db")
        os.makedirs(db_dir, exist_ok=True)

        print(f"Initializing ChromaDB at {db_dir}")
        self.chroma_client = chromadb.PersistentClient(path=db_dir)
        try:
            self.collection = self.chroma_client.get_collection("indian_legal_docs")
            print("Found existing collection")
        except Exception:
            print("Creating new collection")
            self.collection = self.chroma_client.create_collection(
                name="indian_legal_docs",
                metadata={"description": "Indian Criminal Law Documents"}
            )

    def _split_into_chunks(self, text: str) -> List[str]:
        """Split text into smaller chunks while preserving context"""
        # Split on meaningful boundaries
        patterns = [
            r'(?=Chapter \d+)',
            r'(?=Section \d+)',
            r'(?=\n\d+\.\s)',  # Numbered paragraphs
            r'\n\n'
        ]
        # Combine patterns
        split_pattern = '|'.join(patterns)
        sections = re.split(split_pattern, text)

        chunks = []
        current_chunk = ""
        for section in sections:
            section = section.strip()
            if not section:
                continue
            # If section is small enough, add to current chunk
            if len(current_chunk) + len(section) < self.max_chunk_size:
                current_chunk += " " + section
            else:
                # If current chunk is not empty, add it to chunks
                if current_chunk:
                    chunks.append(current_chunk.strip())
                # Start new chunk with current section
                current_chunk = section

        # Add the last chunk if not empty
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def process_pdf(self, pdf_path: str) -> List[str]:
        """Extract text from PDF and split into chunks"""
        print(f"Processing PDF: {pdf_path}")
        try:
            reader = PdfReader(pdf_path)
            text = ""
            for page in reader.pages:
                # extract_text() may return None for pages without extractable text
                text += (page.extract_text() or "") + "\n\n"
            chunks = self._split_into_chunks(text)
            print(f"Created {len(chunks)} chunks from {pdf_path}")
            return chunks
        except Exception as e:
            print(f"Error processing PDF {pdf_path}: {str(e)}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Generate embedding for text"""
        inputs = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**inputs)
        # Mean pooling
        token_embeddings = model_output[0]
        attention_mask = inputs['attention_mask']
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * mask, 1)
        sum_mask = torch.clamp(mask.sum(1), min=1e-9)
        return (sum_embeddings / sum_mask).squeeze().tolist()

    def process_and_store_documents(self):
        """Process all legal documents and store in ChromaDB"""
        print("Starting document processing...")
        # Define the expected PDF paths
        pdf_files = {
            'BNS': os.path.join(self.pdf_dir, 'BNS.pdf'),
            'BNSS': os.path.join(self.pdf_dir, 'BNSS.pdf'),
            'BSA': os.path.join(self.pdf_dir, 'BSA.pdf')
        }

        for law_code, pdf_path in pdf_files.items():
            if not os.path.exists(pdf_path):
                print(f"Warning: {pdf_path} not found, skipping {law_code}")
                continue
            print(f"Processing {law_code} from {pdf_path}")
            chunks = self.process_pdf(pdf_path)
            if not chunks:
                print(f"No chunks extracted from {pdf_path}")
                continue
            for i, chunk in enumerate(chunks):
                try:
                    embedding = self.get_embedding(chunk)
                    self.collection.add(
                        documents=[chunk],
                        embeddings=[embedding],
                        metadatas=[{
                            "law_code": law_code,
                            "chunk_id": f"{law_code}_chunk_{i}",
                            "source": os.path.basename(pdf_path)
                        }],
                        ids=[f"{law_code}_chunk_{i}"]
                    )
                except Exception as e:
                    print(f"Error processing chunk {i} from {law_code}: {str(e)}")

    def search_documents(self, query: str, n_results: int = 3) -> Dict:
        """Search for relevant legal information"""
        try:
            query_embedding = self.get_embedding(query)
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )

            # Limit context size
            documents = results["documents"][0]
            total_length = 0
            filtered_documents = []
            filtered_metadatas = []
            for doc, metadata in zip(documents, results["metadatas"][0]):
                doc_length = len(doc)
                if total_length + doc_length <= self.max_context_length:
                    filtered_documents.append(doc)
                    filtered_metadatas.append(metadata)
                    total_length += doc_length
                else:
                    break

            return {
                "documents": filtered_documents,
                "metadatas": filtered_metadatas
            }
        except Exception as e:
            print(f"Error during search: {str(e)}")
            return {
                "documents": ["Sorry, I couldn't search the documents effectively."],
                "metadatas": [{"law_code": "ERROR", "source": "error"}]
            }
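

# Minimal usage sketch (illustration only, not part of the original Space entry point).
# Assumptions: BNS.pdf, BNSS.pdf and BSA.pdf already sit in /home/user/app as configured
# above, and the query string below is just an arbitrary example.
if __name__ == "__main__":
    processor = LegalDocumentProcessor()
    processor.process_and_store_documents()

    # Query the indexed chunks and print a short preview of each hit
    results = processor.search_documents("What is the punishment for theft?", n_results=3)
    for doc, meta in zip(results["documents"], results["metadatas"]):
        print(f"[{meta.get('law_code', '?')}] {doc[:200]}...")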