# documind-api-v2 / main.py
# pvanand's picture
# Update main.py
# 11ba704 verified
#DOCS
# https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent
import uuid
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel
import json
from typing import Optional, Annotated
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import InjectedState
from document_rag_router import router as document_rag_router
from document_rag_router import QueryInput, query_collection, SearchResult,db
from fastapi import HTTPException
import requests
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import re
import os
from langchain_core.prompts import ChatPromptTemplate
import logging.config
# Configure logging at application startup.
# One stdout handler ("console") is shared by the root logger and by the
# "uvicorn" and "fastapi" loggers. Those two loggers set propagate=False:
# without it each of their records would be written once by their own handler
# and a second time by the identical handler on the root logger.
logging.config.dictConfig({
    "version": 1,
    # Keep loggers created before this call (e.g. by imported libraries) alive.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "default",
            "level": "DEBUG",
        }
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console"],
    },
    "loggers": {
        "uvicorn": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
        "fastapi": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
    },
})
# Module-level logger used by the helpers and endpoints in this file.
logger = logging.getLogger(__name__)

# The ASGI application object served by uvicorn.
app = FastAPI()

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. Confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the routes defined in document_rag_router on the application.
app.include_router(document_rag_router)
def get_current_files():
    """Return the names of the entries in the current working directory.

    Returns:
        The entry names sorted alphabetically and joined with ", ". If the
        directory cannot be listed, a string starting with
        "Error getting files: " followed by the OS error text.
    """
    try:
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # result stable between calls and platforms.
        return ", ".join(sorted(os.listdir(".")))
    except OSError as e:
        return f"Error getting files: {str(e)}"
@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Hard-coded lookup: any name containing "bob" (case-insensitive) maps to
    # 42, every other name to 41.
    return "42 years old" if "bob" in name.lower() else "41 years old"
@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.
    Args:
        query: The search query to find relevant document passages
    """
    # The docstring above is left unchanged on purpose: it belongs to the
    # @tool-decorated function, so it is treated as runtime text here.
    #
    # collection_id and user_id are read from the run config, not from the
    # tool arguments.
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        response = await query_collection(input_data)
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    # "location" may be absent or explicitly None; both mean
                    # there is no chunk index to report.
                    "chunk_index": (r.metadata.get("location") or {}).get("chunk_index"),
                },
            }
            for r in response.results
        ]
        # The result is returned as the str() of a Python list of dicts;
        # clean_tool_response in this file parses that form with
        # ast.literal_eval, so the format must stay a Python literal.
        return str(results)
    except Exception as e:
        # Log with traceback through the configured logger instead of print().
        logger.exception("Error querying documents: %s", e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> SearchResult:
    """Run a collection query and return the result objects unformatted.

    This is a plain coroutine (it is not decorated with @tool); chat2 calls
    it to obtain the full result objects for a finished query_documents call.

    Args:
        query: The search query to find relevant document passages.
        config: Run config; "collection_id" and "user_id" are read from its
            "configurable" mapping.

    Returns:
        ``response.results`` as returned by query_collection on success.
        When either id is missing, or when any exception is raised, an error
        string is returned instead.

    NOTE(review): the ``SearchResult`` annotation does not cover the
    error-string returns, and chat2 iterates the success value as a
    sequence; confirm the intended type and tighten the annotation.
    """
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        response = await query_collection(input_data)
        return response.results
    except Exception as e:
        # Log with traceback through the configured logger instead of print().
        logger.exception("Error querying documents: %s", e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
# Checkpointer passed to the agent below.
memory = MemorySaver()

# Chat model used by the agent; streaming is enabled.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)

# Prompt template: a system message with a {collection_files} slot (filled by
# format_for_model) followed by the conversation messages.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI assistant. "
            "The current collection contains the following files: {collection_files}, "
            "use query_documents tool to answer user queries from the document. "
            "In case a summary is requested, create multiple queries for different "
            "plausible sections of the document",
        ),
        ("placeholder", "{messages}"),
    ]
)
import requests
from requests.exceptions import RequestException, Timeout
import logging
from typing import Optional
# def get_collection_files(collection_id: str, user_id: str) -> str:
# """
# Synchronously get list of files in the specified collection using the external API
# with proper timeout and error handling.
# """
# try:
# url = "https://pvanand-documind-api-v2.hf.space/rag/get_collection_files"
# params = {
# "collection_id": collection_id,
# "user_id": user_id
# }
# headers = {
# 'accept': 'application/json'
# }
# logger.debug(f"Requesting collection files for user {user_id}, collection {collection_id}")
# # Set timeout to 5 seconds
# response = requests.post(url, params=params, headers=headers, data='', timeout=5)
# if response.status_code == 200:
# logger.info(f"Successfully retrieved collection files: {response.text[:100]}...")
# return response.text
# else:
# logger.error(f"API error (status {response.status_code}): {response.text}")
# return f"Error fetching files (status {response.status_code})"
# except Timeout:
# logger.error("Timeout while fetching collection files")
# return "Error: Request timed out"
# except RequestException as e:
# logger.error(f"Network error fetching collection files: {str(e)}")
# return f"Error: Network issue - {str(e)}"
# except Exception as e:
# logger.error(f"Error fetching collection files: {str(e)}", exc_info=True)
# return f"Error fetching files: {str(e)}"
def get_collection_files(collection_id: str, user_id: str) -> str:
    """Return the distinct file names stored in a user's collection.

    The table named "<user_id>_<collection_id>" is opened from ``db`` and
    converted to a pandas DataFrame; the unique values of its "file_name"
    column are reported.

    Args:
        collection_id: Collection part of the table name.
        user_id: User part of the table name.

    Returns:
        The unique file names joined with ", ". On any failure, a string
        starting with "Error getting files: " followed by the error text.
    """
    try:
        collection_name = f"{user_id}_{collection_id}"
        table = db.open_table(collection_name)
        df = table.to_pandas()
        # Preview goes to the debug log rather than to stdout via print().
        logger.debug("Collection %s preview:\n%s", collection_name, df.head())
        # Drop missing names and coerce the rest to str so that join cannot
        # fail on NaN or non-string values.
        unique_files = df["file_name"].dropna().unique()
        return ", ".join(str(name) for name in unique_files)
    except Exception as e:
        logger.error("Error getting collection files: %s", e, exc_info=True)
        return f"Error getting files: {str(e)}"
def format_for_model(state: dict, config: Optional[RunnableConfig] = None) -> list[BaseMessage]:
    """Build the model input from the graph state and the run configuration.

    The prompt template is rendered with a description of the files in the
    caller's collection and with the conversation messages found in ``state``.

    Args:
        state: Current state dictionary; its "messages" entry (an empty list
            when absent) is passed to the template.
        config: Optional run config whose "configurable" mapping supplies
            "collection_id" and "user_id".

    Returns:
        The result of ``prompt.invoke`` for the rendered template.
    """
    configurable = (config or {}).get("configurable", {})
    collection_id = configurable.get("collection_id")
    user_id = configurable.get("user_id")

    def render(files_description):
        # Single place that feeds the template, shared by the normal and the
        # fallback path.
        return prompt.invoke({
            "collection_files": files_description,
            "messages": state.get("messages", []),
        })

    try:
        # Without both ids there is no collection to look up.
        collection_files = (
            get_collection_files(collection_id, user_id)
            if collection_id and user_id
            else "No files available"
        )
        logger.info(f"Fetching collection for userid {user_id} and collection_id {collection_id} || Results: {collection_files[:100]}...")
        return render(collection_files)
    except Exception as e:
        logger.error(f"Error in format_for_model: {str(e)}", exc_info=True)
        # Fall back to a prompt that states the file list could not be fetched.
        return render("Error fetching files")
async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a stringified tool-input dict.

    The input is expected to look like ``str({'query': '...'})``. It is first
    parsed as a Python literal, which copes with values that repr() renders
    in double quotes (any value containing an apostrophe) and decodes escape
    sequences such as ``\\n``. If that does not yield a string-to-string
    first pair, the original regular-expression match is used as a fallback.

    Args:
        tool_input: Stringified tool input.

    Returns:
        ``{key: value}`` for the first pair when one can be extracted,
        otherwise ``[tool_input]`` (the unchanged input wrapped in a list).
    """
    import ast  # local import: ast is not imported at file level

    try:
        parsed = ast.literal_eval(tool_input.strip())
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, dict) and parsed:
        key, value = next(iter(parsed.items()))
        # Only non-empty string pairs are returned from this path, matching
        # what the regex fallback below can produce.
        if isinstance(key, str) and isinstance(value, str) and key and value:
            return {key: value}

    # Fallback: match the first 'key': 'value' pair following an opening brace.
    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
    match = re.search(pattern, tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    return [tool_input]
async def clean_tool_response(tool_output: str):
    """Reduce a query_documents tool response to text and document id.

    Only outputs containing the substring "query_documents" are processed.
    The span from the first "[{" to the last "}]" is parsed as a Python
    literal and each entry is reduced to its "text" and
    ``metadata["document_id"]``.

    Args:
        tool_output: Stringified tool output.

    Returns:
        A list of ``{"text": ..., "document_id": ...}`` dicts when a result
        list is found and parsed. The unchanged ``tool_output`` when it does
        not mention query_documents or contains no "[{" ... "}]" span. An
        error string ("Error parsing document results: ..." or
        "Error processing results: ...") when parsing or field access fails.
    """
    import ast

    if "query_documents" not in tool_output:
        return tool_output

    log = logging.getLogger(__name__)
    log.debug("Cleaning tool output: %s", tool_output)

    start = tool_output.find("[{")
    end = tool_output.rfind("}]")
    # Test the raw find() results: adding the closing-bracket length first
    # would hide the -1 that signals "not found".
    if start < 0 or end < start:
        return tool_output

    try:
        results = ast.literal_eval(tool_output[start:end + 2])
        return [
            {"text": r["text"], "document_id": r["metadata"]["document_id"]}
            for r in results
        ]
    except SyntaxError as e:
        log.warning("Syntax error in parsing: %s", e)
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        log.warning("General error: %s", e)
        return f"Error processing results: {str(e)}"
# Agent shared by the /chat and /chat2 endpoints: the chat model with the
# query_documents tool, the in-memory checkpointer, and format_for_model
# supplied as the state modifier.
agent = create_react_agent(
    model=model,
    tools=[query_documents],
    state_modifier=format_for_model,
    checkpointer=memory,
)
class ChatInput(BaseModel):
    # Request body accepted by the /chat and /chat2 endpoints.

    # Text of the user's message for this turn.
    message: str
    # Conversation thread key; the endpoints generate a new UUID when omitted.
    thread_id: Optional[str] = None
    # Both ids are forwarded in the agent's run config; query_documents
    # returns an error string when either one is missing.
    collection_id: Optional[str] = None
    user_id: Optional[str] = None
@app.post("/chat")
async def chat(input_data: ChatInput):
    # Streams the agent's reply as server-sent events. Each event's data is a
    # JSON object whose "type" is "token", "tool_start" or "tool_end".
    run_config = {
        "configurable": {
            # Continue the caller's thread, or start a new one.
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_turn = HumanMessage(content=input_data.message)

    async def generate():
        stream = agent.astream_events(
            {"messages": [user_turn]},
            run_config,
            version="v2",
        )
        async for event in stream:
            event_type = event["event"]
            if event_type == "on_chat_model_stream":
                text = event["data"]["chunk"].content
                if not text:
                    # Chunks without content carry nothing to show.
                    continue
                payload = {"type": "token", "content": text}
            elif event_type == "on_tool_start":
                payload = {
                    "type": "tool_start",
                    "tool": event["name"],
                    "input": str(event["data"].get("input", "")),
                }
            elif event_type == "on_tool_end":
                payload = {
                    "type": "tool_end",
                    "tool": event["name"],
                    "output": str(event["data"].get("output", "")),
                }
            else:
                # All other event kinds are not forwarded.
                continue
            yield json.dumps(payload)

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream",
    )
@app.post("/chat2")
async def chat2(input_data: ChatInput):
    """Stream the agent's reply as server-sent events with cleaned tool data.

    Each event's data is a JSON object with one of these shapes:

    * ``{"type": "token", "content": ...}`` for model output chunks.
    * ``{"type": "tool_start", "tool": ..., "inputs": ...}`` where "inputs"
      is the result of clean_tool_input.
    * ``{"type": "tool_end", "tool": ..., "output": ...}``. For
      query_documents the search is run again through query_documents_raw
      and "output" is a JSON string of the full results (or the error text);
      for other tools "output" is the result of clean_tool_response.

    Args:
        input_data: Message plus optional thread, collection and user ids.

    Returns:
        An EventSourceResponse wrapping the event generator.
    """
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2",
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({"type": "token", "content": content})
            elif kind == "on_tool_start":
                tool_name = event["name"]
                tool_input = event["data"].get("input", "")
                clean_input = await clean_tool_input(str(tool_input))
                yield json.dumps({"type": "tool_start", "tool": tool_name, "inputs": clean_input})
            elif kind == "on_tool_end":
                tool_name = event["name"]
                if "query_documents" in tool_name:
                    logger.debug("query_documents finished: %s", event)
                    tool_input = event["data"].get("input", "")
                    # Search with the query the tool actually received. Passing
                    # str(tool_input) would search for the dict's repr
                    # ("{'query': '...'}") instead of the query text.
                    query = tool_input.get("query") if isinstance(tool_input, dict) else None
                    if not isinstance(query, str):
                        # Unknown input shape: fall back to its string form.
                        query = str(tool_input)
                    raw_output = await query_documents_raw(query, config)
                    if isinstance(raw_output, str):
                        # query_documents_raw reports failures as a plain string.
                        output = raw_output
                    else:
                        try:
                            output = json.dumps([
                                {
                                    "text": result.text,
                                    "distance": result.distance,
                                    "metadata": result.metadata,
                                }
                                for result in raw_output
                            ])
                        except Exception:
                            # E.g. metadata that is not JSON-serializable.
                            logger.exception("Could not serialize query_documents results")
                            output = str(raw_output)
                    yield json.dumps({"type": "tool_end", "tool": tool_name, "output": output})
                else:
                    raw_output = str(event["data"].get("output", ""))
                    clean_output = await clean_tool_response(raw_output)
                    yield json.dumps({"type": "tool_end", "tool": tool_name, "output": clean_output})

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream",
    )
@app.get("/health")
async def health_check():
    # Liveness endpoint: always reports a fixed "healthy" status.
    status_report = {"status": "healthy"}
    return status_report
if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on all interfaces,
    # port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")