# documind-api-v2 / document_rag_router.py
from fastapi import UploadFile, File, Form, HTTPException, APIRouter
from typing import List, Optional, Dict, Tuple
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
from utils import process_pdf_to_chunks
import hashlib
import uuid
import json
from datetime import datetime
from pydantic import BaseModel
import logging

# Create router
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Initialize LanceDB and embedding model
db = lancedb.connect("/tmp/db")
model = get_registry().get("sentence-transformers").create(
    name="Snowflake/snowflake-arctic-embed-xs",
    device="cpu"
)
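
# snowflake-arctic-embed-xs produces 384-dimensional embeddings (per its
# model card); the schema below sizes its vector column via model.ndims()
# rather than hard-coding the width. Illustrative check:
#
#   assert model.ndims() == 384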


def get_user_collection(user_id: str, collection_name: str) -> str:
    """Generate user-specific collection name"""
    return f"{user_id}_{collection_name}"
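
# e.g. get_user_collection("user_1", "notes") -> "user_1_notes"; prefixing
# table names with the owning user's id namespaces collections per user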


class DocumentChunk(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()
    document_id: str
    chunk_index: int
    file_name: str
    file_type: str
    created_date: str
    collection_id: str
    user_id: str
    metadata_json: str
    char_start: int
    char_end: int
    page_numbers: List[int]
    images: List[str]
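
# A row in a collection table looks roughly like this (illustrative values;
# the vector is filled in automatically by LanceDB from the `text` source
# field):
#
#   {
#       "text": "First paragraph of the report...",
#       "document_id": "a1b2",
#       "chunk_index": 0,
#       "file_name": "report.pdf",
#       "file_type": "pdf",
#       "created_date": "2024-01-01T00:00:00",
#       "collection_id": "user_1_col_ab12cd34",
#       "user_id": "user_1",
#       "metadata_json": "{...}",
#       "char_start": 0,
#       "char_end": 512,
#       "page_numbers": [1],
#       "images": [],
#   }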


class QueryInput(BaseModel):
    collection_id: str
    query: str
    top_k: Optional[int] = 3
    user_id: str


class SearchResult(BaseModel):
    text: str
    distance: float
    metadata: Dict  # chunk metadata, parsed from the stored metadata_json


class SearchResponse(BaseModel):
    results: List[SearchResult]


async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
    """Process a single file and return its chunks with metadata plus the document id"""
    content = await file.read()
    file_type = file.filename.split('.')[-1].lower()
    chunks = []
    doc_id = ""

    if file_type == 'pdf':
        chunks, doc_id = process_pdf_to_chunks(
            pdf_content=content,
            file_name=file.filename
        )
    elif file_type == 'txt':
        doc_id = hashlib.sha256(content).hexdigest()[:4]
        text_content = content.decode('utf-8')
        chunks = [{
            "text": text_content,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file.filename,
                "document_id": doc_id,
                "user_id": user_id,
                "location": {
                    "chunk_index": 0,
                    "char_start": 0,
                    "char_end": len(text_content),
                    "pages": [1],
                    "total_chunks": 1
                },
                "images": []
            }
        }]
    else:
        # Reject unsupported extensions explicitly; a silent fall-through
        # would yield an empty doc_id and no chunks. The caller maps this
        # to a 400 response.
        raise ValueError(f"Unsupported file type: {file_type}")

    return chunks, doc_id
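
# Document ids for .txt uploads are the first 4 hex chars of the content's
# SHA-256, e.g. (illustrative):
#
#   hashlib.sha256(b"hello").hexdigest()[:4]  # -> "2cf2"
#
# Note that 4 hex chars (65536 values) can collide across documents; ids for
# PDFs come from process_pdf_to_chunks instead.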


@router.post("/upload_files")
async def upload_files(
    files: List[UploadFile] = File(...),
    collection_name: Optional[str] = Form(None),
    user_id: str = Form(...)
):
    try:
        collection_id = get_user_collection(
            user_id,
            collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
        )

        all_chunks = []
        doc_ids = {}
        for file in files:
            try:
                chunks, doc_id = await process_file(file, collection_id, user_id)
                for chunk in chunks:
                    chunk_data = {
                        "text": chunk["text"],
                        "document_id": chunk["metadata"]["document_id"],
                        "chunk_index": chunk["metadata"]["location"]["chunk_index"],
                        "file_name": chunk["metadata"]["file_name"],
                        "file_type": file.filename.split('.')[-1].lower(),
                        "created_date": chunk["metadata"]["created_date"],
                        "collection_id": collection_id,
                        "user_id": user_id,
                        "metadata_json": json.dumps(chunk["metadata"]),
                        "char_start": chunk["metadata"]["location"]["char_start"],
                        "char_end": chunk["metadata"]["location"]["char_end"],
                        "page_numbers": chunk["metadata"]["location"]["pages"],
                        "images": chunk["metadata"].get("images", [])
                    }
                    all_chunks.append(chunk_data)
                doc_ids[doc_id] = file.filename
            except Exception as e:
                logging.error(f"Error processing file {file.filename}: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing file {file.filename}: {str(e)}"
                )

        # Open the collection's table, creating it with the DocumentChunk
        # schema on first upload
        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error(f"Error opening table: {str(e)}")
            try:
                table = db.create_table(
                    collection_id,
                    schema=DocumentChunk,
                    mode="create"
                )
                # Create FTS index on the text column for hybrid search support
                # table.create_fts_index(
                #     field_names="text",
                #     replace=True,
                #     tokenizer_name="en_stem",  # Use English stemming
                #     lower_case=True,  # Convert text to lowercase
                #     remove_stop_words=True,  # Remove common words like "the", "is", "at"
                #     writer_heap_size=1024 * 1024 * 1024  # 1GB heap size
                # )
            except Exception as e:
                logging.error(f"Error creating table: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Error creating database table: {str(e)}"
                )

        try:
            df = pd.DataFrame(all_chunks)
            table.add(data=df)
        except Exception as e:
            logging.error(f"Error adding data to table: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error adding data to database: {str(e)}"
            )

        # NOTE: the returned collection_id is the full user-prefixed table
        # name; the query/get/delete endpoints below expect the unprefixed
        # name and prepend "{user_id}_" themselves
        return {
            "message": f"Successfully processed {len(files)} files",
            "collection_id": collection_id,
            "total_chunks": len(all_chunks),
            "user_id": user_id,
            "document_ids": doc_ids
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during file upload: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
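
# Example upload (illustrative; host and port depend on how the app is
# served):
#
#   curl -X POST "http://localhost:8000/rag/upload_files" \
#        -F "files=@report.pdf" \
#        -F "files=@notes.txt" \
#        -F "collection_name=my_docs" \
#        -F "user_id=user_1"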


@router.get("/get_document/{collection_id}/{document_id}")
async def get_document(
    collection_id: str,
    document_id: str,
    user_id: str
):
    try:
        table = db.open_table(f"{user_id}_{collection_id}")
    except Exception as e:
        logging.error(f"Error opening table: {str(e)}")
        raise HTTPException(
            status_code=404,
            detail=f"Collection not found: {str(e)}"
        )

    try:
        chunks = table.to_pandas()
        doc_chunks = chunks[
            (chunks['document_id'] == document_id) &
            (chunks['user_id'] == user_id)
        ].sort_values('chunk_index')

        if len(doc_chunks) == 0:
            raise HTTPException(
                status_code=404,
                detail=f"Document {document_id} not found in collection {collection_id}"
            )

        return {
            "document_id": document_id,
            "file_name": doc_chunks.iloc[0]['file_name'],
            "chunks": [
                {
                    "text": row['text'],
                    "metadata": json.loads(row['metadata_json'])
                }
                for _, row in doc_chunks.iterrows()
            ]
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Error retrieving document: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving document: {str(e)}"
        )
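
# Example (illustrative ids):
#
#   curl "http://localhost:8000/rag/get_document/my_docs/a1b2?user_id=user_1"
#
# collection_id here is the unprefixed name; the handler prepends
# "{user_id}_" to locate the physical table.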


@router.post("/query_collection", response_model=SearchResponse)
async def query_collection(input_data: QueryInput):
    try:
        collection_id = get_user_collection(input_data.user_id, input_data.collection_id)
        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error(f"Error opening table: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection not found: {str(e)}"
            )

        try:
            # NOTE: user_id is interpolated into the filter expression and is
            # assumed to be a trusted, sanitized identifier
            results = (
                table.search(input_data.query)
                .where(f"user_id = '{input_data.user_id}'")
                .limit(input_data.top_k)
                .to_list()
            )
        except Exception as e:
            logging.error(f"Error searching collection: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error searching collection: {str(e)}"
            )

        return SearchResponse(results=[
            SearchResult(
                text=r['text'],
                distance=float(r['_distance']),
                metadata=json.loads(r['metadata_json'])
            )
            for r in results
        ])
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
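
# Example query (illustrative):
#
#   curl -X POST "http://localhost:8000/rag/query_collection" \
#        -H "Content-Type: application/json" \
#        -d '{"collection_id": "my_docs", "query": "What were Q3 revenues?",
#             "top_k": 3, "user_id": "user_1"}'
#
# Returns the top_k nearest chunks; with the default distance metric, a
# smaller distance means a closer match.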


@router.get("/list_collections")
async def list_collections(user_id: str):
    try:
        all_collections = db.table_names()
        user_collections = [
            c for c in all_collections
            if c.startswith(f"{user_id}_")
        ]

        # Get documents for each collection
        collections_info = []
        for collection_name in user_collections:
            try:
                table = db.open_table(collection_name)
                df = table.to_pandas()

                # Group by document_id to get unique documents
                documents = df.groupby('document_id').agg({
                    'file_name': 'first',
                    'created_date': 'first'
                }).reset_index()

                collections_info.append({
                    "collection_id": collection_name.replace(f"{user_id}_", ""),
                    "documents": [
                        {
                            "document_id": row['document_id'],
                            "file_name": row['file_name'],
                            "created_date": row['created_date']
                        }
                        for _, row in documents.iterrows()
                    ]
                })
            except Exception as e:
                logging.error(f"Error processing collection {collection_name}: {str(e)}")
                continue

        return {"collections": collections_info}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
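
# Example response shape (illustrative values):
#
#   {"collections": [{"collection_id": "my_docs",
#                     "documents": [{"document_id": "a1b2",
#                                    "file_name": "report.pdf",
#                                    "created_date": "2024-01-01T00:00:00"}]}]}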


@router.delete("/delete_collection/{collection_id}")
async def delete_collection(collection_id: str, user_id: str):
    try:
        full_collection_id = f"{user_id}_{collection_id}"

        # Check that the collection exists. Ownership is enforced by
        # construction: the physical table name is always prefixed with the
        # caller's user_id, so a user can only ever address their own tables.
        try:
            db.open_table(full_collection_id)
        except Exception as e:
            logging.error(f"Collection not found: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found"
            )

        try:
            db.drop_table(full_collection_id)
        except Exception as e:
            logging.error(f"Error deleting collection {collection_id}: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error deleting collection: {str(e)}"
            )

        return {
            "message": f"Collection {collection_id} deleted successfully",
            "collection_id": collection_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error deleting collection {collection_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
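
# Example (illustrative):
#
#   curl -X DELETE "http://localhost:8000/rag/delete_collection/my_docs?user_id=user_1"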


@router.post("/query_collection_tool")
async def query_collection_tool(input_data: QueryInput):
    try:
        response = await query_collection(input_data)
        results = []
        # Access response attributes directly since it's a Pydantic model
        for r in response.results:
            result_dict = {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            results.append(result_dict)
        # Return a stringified list rather than a JSON response, a compact
        # payload suited to tool-style consumers
        return str(results)
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
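
# To serve this router, include it in a FastAPI app (a minimal sketch; the
# surrounding module and app names are assumptions, not part of this file):
#
#   from fastapi import FastAPI
#   from document_rag_router import router as rag_router
#
#   app = FastAPI()
#   app.include_router(rag_router)
#   # run with: uvicorn main:app --host 0.0.0.0 --port 8000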