# documind-api-v2 / document_rag_router.py
# Author: pvanand
# Last commit: c182518 (verified) - "Update document_rag_router.py"
from fastapi import UploadFile, File, Form, HTTPException, APIRouter
from typing import List, Optional, Dict, Tuple
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
from utils import process_pdf_to_chunks
import hashlib
import uuid
import json
from datetime import datetime
from pydantic import BaseModel
import logging
# Router exposing every endpoint of this module under the "/rag" prefix.
router = APIRouter(prefix="/rag", tags=["rag"])

# Module-level LanceDB connection rooted at /tmp/db; every collection in this
# module is a table inside this database.
db = lancedb.connect("/tmp/db")

# Sentence-transformers embedding function, run on CPU. DocumentChunk binds its
# `text` (source) and `vector` (target) fields to it, which is what lets the
# table embed rows on insert and embed plain-text search queries.
model = (
    get_registry()
    .get("sentence-transformers")
    .create(name="Snowflake/snowflake-arctic-embed-xs", device="cpu")
)
def get_user_collection(user_id: str, collection_name: str) -> str:
    """Return the LanceDB table name that stores a user's collection.

    The table name is the user id and the collection name joined by a single
    underscore, e.g. ``("alice", "notes") -> "alice_notes"``.

    NOTE(review): the mapping is not injective when ``user_id`` contains an
    underscore -- ``("a_b", "c")`` and ``("a", "b_c")`` both map to
    ``"a_b_c"``. Changing the scheme would orphan existing tables, so it is
    left as-is here.
    """
    return "{}_{}".format(user_id, collection_name)
class DocumentChunk(LanceModel):
    """LanceDB row schema for one stored chunk of an uploaded document.

    The field declaration order becomes the column order of the table schema,
    so fields should not be reordered casually.
    """

    # Chunk text; declared as the embedding function's source field.
    text: str = model.SourceField()
    # Embedding of `text`, sized to the model's output dimension. upload_files
    # never supplies this column, so it is presumably populated by the LanceDB
    # embedding function when rows are added.
    vector: Vector(model.ndims()) = model.VectorField()
    # Id of the source document. For .txt uploads this is the first 4 hex
    # characters of the SHA-256 of the file bytes; PDF ids come from
    # process_pdf_to_chunks.
    document_id: str
    # Position of this chunk within its document (metadata.location.chunk_index).
    chunk_index: int
    # Original name of the uploaded file.
    file_name: str
    # Lower-cased file extension of the upload, e.g. "pdf" or "txt".
    file_type: str
    # Creation timestamp string; ISO-8601 for .txt uploads (assumed the same
    # for PDFs -- TODO confirm in process_pdf_to_chunks).
    created_date: str
    # Full table name ("<user_id>_<collection>") the chunk was written to.
    collection_id: str
    # Id of the user who uploaded the document.
    user_id: str
    # The chunk's complete metadata dict, serialized with json.dumps.
    metadata_json: str
    # Character offsets of the chunk in the document text (metadata.location).
    char_start: int
    char_end: int
    # Page numbers the chunk spans (metadata.location.pages).
    page_numbers: List[int]
    # Image references attached to the chunk; their format is defined by the
    # chunking code (presumably paths or ids -- verify).
    images: List[str]
class QueryInput(BaseModel):
    """Request body for the collection search endpoints."""

    # Collection name WITHOUT the user prefix; it is combined with user_id to
    # locate the table.
    collection_id: str
    # Free-text query; it is passed as a string to the table search.
    query: str
    # Maximum number of hits to return (default 3). Declared Optional, so a
    # client may send an explicit null.
    top_k: Optional[int] = 3
    # Owner of the collection; also used to filter rows during the search.
    user_id: str
class SearchResult(BaseModel):
    """A single hit returned by a collection search."""

    # Text of the matching chunk.
    text: str
    # LanceDB's `_distance` value for the hit (a distance: smaller is closer).
    distance: float
    # Chunk metadata, decoded from the row's `metadata_json` column.
    metadata: Dict
class SearchResponse(BaseModel):
    """Response body of the collection search endpoint."""

    # Search hits, one entry per matching chunk.
    results: List[SearchResult]
async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
    """Read one uploaded file and split it into chunk dicts.

    The handler is chosen by the lower-cased file extension:

    * ``pdf`` -- delegated to ``process_pdf_to_chunks``.
    * ``txt`` -- decoded as UTF-8 (a leading byte-order mark is dropped) and
      returned as a single chunk that covers the whole text.

    Args:
        file: The uploaded file; its full content is read into memory.
        collection_id: Target collection. Not used here; kept for callers.
        user_id: Uploading user, recorded in the metadata of .txt chunks.

    Returns:
        ``(chunks, doc_id)``, where every chunk is a dict
        ``{"text": str, "metadata": {...}}`` with ``document_id``,
        ``file_name``, ``created_date`` and a ``location`` sub-dict.

    Raises:
        ValueError: If the upload has no filename or an unsupported extension.
        UnicodeDecodeError: If a ``.txt`` file is not valid UTF-8.
    """
    if not file.filename:
        raise ValueError("Uploaded file has no filename")

    content = await file.read()
    file_type = file.filename.split('.')[-1].lower()

    if file_type == 'pdf':
        chunks, doc_id = process_pdf_to_chunks(
            pdf_content=content,
            file_name=file.filename
        )
        return chunks, doc_id

    if file_type == 'txt':
        # NOTE(review): 4 hex digits give only 65,536 distinct ids, so
        # different files will eventually collide; kept as-is so ids stay
        # consistent with documents that are already stored.
        doc_id = hashlib.sha256(content).hexdigest()[:4]
        # "utf-8-sig" strips a leading BOM so it does not end up in the text.
        text_content = content.decode('utf-8-sig')
        chunks = [{
            "text": text_content,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file.filename,
                "document_id": doc_id,
                "user_id": user_id,
                "location": {
                    "chunk_index": 0,
                    "char_start": 0,
                    "char_end": len(text_content),
                    "pages": [1],
                    "total_chunks": 1
                },
                "images": []
            }
        }]
        return chunks, doc_id

    # Fail loudly: an unsupported type would otherwise yield no chunks and an
    # empty document id that the caller would report as a success.
    raise ValueError(
        f"Unsupported file type '{file_type}' (supported: pdf, txt)"
    )
def _chunk_to_row(chunk: dict, file_type: str, collection_id: str, user_id: str) -> dict:
    """Flatten one chunk dict produced by process_file into a DocumentChunk row."""
    metadata = chunk["metadata"]
    location = metadata["location"]
    return {
        "text": chunk["text"],
        "document_id": metadata["document_id"],
        "chunk_index": location["chunk_index"],
        "file_name": metadata["file_name"],
        "file_type": file_type,
        "created_date": metadata["created_date"],
        "collection_id": collection_id,
        "user_id": user_id,
        "metadata_json": json.dumps(metadata),
        "char_start": location["char_start"],
        "char_end": location["char_end"],
        "page_numbers": location["pages"],
        "images": metadata.get("images", []),
    }


def _open_or_create_table(collection_id: str):
    """Open the collection's table, creating it with the DocumentChunk schema if needed.

    Raises:
        HTTPException: 500 if the table can neither be opened nor created.
    """
    try:
        return db.open_table(collection_id)
    except Exception as e:
        # Expected for a brand-new collection, so not logged as an error.
        logging.info(f"Table {collection_id} could not be opened ({str(e)}); creating it")
    try:
        # NOTE: hybrid search would additionally need a full-text index on
        # `text` (table.create_fts_index); it is intentionally not created.
        return db.create_table(
            collection_id,
            schema=DocumentChunk,
            mode="create"
        )
    except Exception as e:
        logging.error(f"Error creating table: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error creating database table: {str(e)}"
        ) from e


@router.post("/upload_files")
async def upload_files(
    files: List[UploadFile] = File(...),
    collection_name: Optional[str] = Form(None),
    user_id: str = Form(...)
):
    """Chunk the uploaded files and store them in one of the user's collections.

    Form fields:
        files: Files to ingest; process_file decides which types are accepted.
        collection_name: Collection to add to. When omitted or empty, a random
            ``col_<8 hex chars>`` name is generated.
        user_id: Owner of the collection (used as the table-name prefix).

    Returns:
        A summary containing the full table name as ``collection_id``
        (``"<user_id>_<name>"``), the bare ``collection_name`` that the other
        endpoints expect, the number of stored chunks and a
        ``document_id -> file name`` mapping.

    Raises:
        HTTPException: 400 if a file cannot be processed or no chunks were
            produced at all; 500 if the table cannot be created or written.
    """
    try:
        resolved_name = collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
        collection_id = get_user_collection(user_id, resolved_name)

        all_chunks = []
        doc_ids = {}
        for file in files:
            try:
                chunks, doc_id = await process_file(file, collection_id, user_id)
                file_type = file.filename.split('.')[-1].lower()
                all_chunks.extend(
                    _chunk_to_row(chunk, file_type, collection_id, user_id)
                    for chunk in chunks
                )
                doc_ids[doc_id] = file.filename
            except Exception as e:
                logging.error(f"Error processing file {file.filename}: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing file {file.filename}: {str(e)}"
                ) from e

        if not all_chunks:
            # Nothing to store. Reject before touching the database so that no
            # empty table is created and table.add never sees an empty frame.
            raise HTTPException(
                status_code=400,
                detail="No content could be extracted from the uploaded files"
            )

        table = _open_or_create_table(collection_id)
        try:
            table.add(data=pd.DataFrame(all_chunks))
        except Exception as e:
            logging.error(f"Error adding data to table: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error adding data to database: {str(e)}"
            ) from e

        return {
            "message": f"Successfully processed {len(files)} files",
            "collection_id": collection_id,
            "collection_name": resolved_name,
            "total_chunks": len(all_chunks),
            "user_id": user_id,
            "document_ids": doc_ids
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.exception(f"Unexpected error during file upload: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e
@router.get("/get_document/{collection_id}/{document_id}")
async def get_document(
    collection_id: str,
    document_id: str,
    user_id: str
):
    """Return every chunk of one document, ordered by chunk index.

    ``collection_id`` is the collection name without the user prefix. Only
    rows whose ``user_id`` column equals the caller's id are considered.

    Raises:
        HTTPException: 404 when the collection table cannot be opened or holds
            no matching chunk; 500 on any other failure.
    """
    try:
        table = db.open_table(get_user_collection(user_id, collection_id))
    except Exception as err:
        logging.error(f"Error opening table: {str(err)}")
        raise HTTPException(
            status_code=404,
            detail=f"Collection not found: {str(err)}"
        )

    try:
        frame = table.to_pandas()
        is_match = (frame['document_id'] == document_id) & (frame['user_id'] == user_id)
        matches = frame[is_match].sort_values('chunk_index')

        if matches.empty:
            raise HTTPException(
                status_code=404,
                detail=f"Document {document_id} not found in collection {collection_id}"
            )

        first_file_name = matches.iloc[0]['file_name']
        chunk_list = [
            {"text": record['text'], "metadata": json.loads(record['metadata_json'])}
            for _, record in matches.iterrows()
        ]
        return {
            "document_id": document_id,
            "file_name": first_file_name,
            "chunks": chunk_list,
        }
    except HTTPException:
        raise
    except Exception as err:
        logging.error(f"Error retrieving document: {str(err)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving document: {str(err)}"
        )
def _sql_string_literal(value: str) -> str:
    """Quote *value* as a single-quoted SQL string literal for a LanceDB filter.

    Embedded single quotes are doubled (the standard SQL escape), so a value
    such as ``o'brien`` neither breaks the filter nor injects extra clauses.
    """
    return "'" + value.replace("'", "''") + "'"


@router.post("/query_collection", response_model=SearchResponse)
async def query_collection(input_data: QueryInput):
    """Vector-search one of the user's collections.

    The query text is searched against the table
    ``get_user_collection(user_id, collection_id)``. Results are restricted to
    rows uploaded by ``user_id`` and capped at ``top_k`` (3 when null).

    Returns:
        SearchResponse with one SearchResult per hit, in LanceDB's result order.

    Raises:
        HTTPException: 400 for a non-positive ``top_k``; 404 if the collection
            table cannot be opened; 500 if the search or anything else fails.
    """
    try:
        # top_k is Optional in QueryInput; an explicit null falls back to the
        # model's default (3) instead of reaching .limit(None).
        top_k = 3 if input_data.top_k is None else input_data.top_k
        if top_k < 1:
            raise HTTPException(
                status_code=400,
                detail="top_k must be a positive integer"
            )

        collection_id = get_user_collection(input_data.user_id, input_data.collection_id)
        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error(f"Error opening table: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection not found: {str(e)}"
            ) from e

        try:
            results = (
                table.search(input_data.query)
                # user_id comes from the request body; quote it safely rather
                # than interpolating it raw into the SQL filter.
                .where(f"user_id = {_sql_string_literal(input_data.user_id)}")
                .limit(top_k)
                .to_list()
            )
        except Exception as e:
            logging.error(f"Error searching collection: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error searching collection: {str(e)}"
            ) from e

        return SearchResponse(results=[
            SearchResult(
                text=r['text'],
                distance=float(r['_distance']),
                metadata=json.loads(r['metadata_json'])
            )
            for r in results
        ])
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e
@router.get("/list_collections")
async def list_collections(user_id: str):
    """List the caller's collections together with the documents in each.

    Tables are selected by the ``"<user_id>_"`` name prefix, and that prefix
    is removed to recover the bare collection name. Because the prefix alone
    is ambiguous (user ``"a"`` also matches tables of user ``"a_b"``), only
    rows whose ``user_id`` equals the caller are reported. Tables that hold
    data but none of it belongs to the caller are skipped. A collection that
    cannot be read is logged and skipped rather than failing the listing.

    Raises:
        HTTPException: 500 if the table names cannot be listed.
    """
    prefix = f"{user_id}_"
    try:
        collections_info = []
        for table_name in db.table_names():
            if not table_name.startswith(prefix):
                continue
            try:
                df = db.open_table(table_name).to_pandas()
                own_rows = df[df['user_id'] == user_id]
                if own_rows.empty and not df.empty:
                    # Another user's table that merely shares the name prefix.
                    continue

                # One entry per document: chunks of a document share these fields.
                documents = own_rows.groupby('document_id').agg({
                    'file_name': 'first',
                    'created_date': 'first'
                }).reset_index()

                collections_info.append({
                    # Slice off only the leading prefix; str.replace would also
                    # drop later occurrences ("u_menu_items" -> "menitems").
                    "collection_id": table_name[len(prefix):],
                    "documents": [
                        {
                            "document_id": row['document_id'],
                            "file_name": row['file_name'],
                            "created_date": row['created_date']
                        }
                        for _, row in documents.iterrows()
                    ]
                })
            except Exception as e:
                logging.error(f"Error processing collection {table_name}: {str(e)}")
                continue
        return {"collections": collections_info}
    except Exception as e:
        logging.error(f"Error listing collections for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/delete_collection/{collection_id}")
async def delete_collection(collection_id: str, user_id: str):
    """Drop one of the caller's collections.

    ``collection_id`` is the collection name without the user prefix.

    Raises:
        HTTPException: 404 if the table cannot be opened; 403 if it contains
            rows uploaded by a different user; 500 if reading or dropping the
            table fails.
    """
    try:
        full_collection_id = get_user_collection(user_id, collection_id)
        # Check if collection exists
        try:
            table = db.open_table(full_collection_id)
        except Exception as e:
            logging.error(f"Collection not found: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found"
            ) from e

        # Verify ownership. The table name starts with "<user_id>_" by
        # construction, so checking that prefix proves nothing (user "a"
        # deleting "b_x" hits the same table as user "a_b" deleting "x").
        # Refuse instead if any stored row was uploaded by someone else.
        owners = table.to_pandas()['user_id']
        if (owners != user_id).any():
            logging.error(f"Unauthorized deletion attempt for collection {collection_id} by user {user_id}")
            raise HTTPException(
                status_code=403,
                detail="Not authorized to delete this collection"
            )

        try:
            db.drop_table(full_collection_id)
        except Exception as e:
            logging.error(f"Error deleting collection {collection_id}: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error deleting collection: {str(e)}"
            ) from e

        return {
            "message": f"Collection {collection_id} deleted successfully",
            "collection_id": collection_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error deleting collection {collection_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e
@router.post("/get_collection_files")
def get_collection_files(collection_id: str, user_id: str) -> str:
    """Return the distinct file names of a collection as one comma-separated string.

    ``collection_id`` is the collection name without the user prefix. Errors
    are not raised; they are returned as an ``"Error getting files: ..."``
    string instead.
    """
    try:
        frame = db.open_table(f"{user_id}_{collection_id}").to_pandas()
        logging.info(f"fetched chunks {str(frame.head())}")
        # Series.unique() keeps order of first appearance.
        return ", ".join(frame['file_name'].unique())
    except Exception as exc:
        logging.error(f"Error getting collection files: {str(exc)}")
        return f"Error getting files: {str(exc)}"
@router.post("/query_collection_tool")
async def query_collection_tool(input_data: QueryInput):
    """Run query_collection and return a compact, stringified list of hits.

    Each hit is reduced to its text, distance, document id and chunk index.
    The list is returned as ``str(list_of_dicts)`` (a Python repr, not JSON).

    Raises:
        HTTPException: Errors raised by query_collection keep their original
            status code (e.g. 404 for a missing collection); any other failure
            becomes a 500.
    """
    try:
        response = await query_collection(input_data)
        results = []
        for r in response.results:
            # Tolerate a missing or null "location" entry in stored metadata.
            location = r.metadata.get("location") or {}
            results.append({
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": location.get("chunk_index")
                }
            })
        return str(results)
    except HTTPException:
        # Propagate unchanged so that, e.g., a 404 is not reported as a 500.
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e