# NOTE: stray "Spaces: / Sleeping / Sleeping" banner text (Hugging Face Spaces
# status UI) was captured here during extraction; it is not part of the code.
#DOCS
# https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent
import json
import logging.config
import os
import re
import uuid
from typing import Annotated, Optional

import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import InjectedState, create_react_agent
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from document_rag_router import QueryInput, SearchResult, db, query_collection
from document_rag_router import router as document_rag_router
# Configure logging at application startup
# dictConfig sends the root, uvicorn and fastapi loggers to stdout at DEBUG;
# disable_existing_loggers=False keeps loggers created by imported libraries
# alive and re-routes them through the console handler declared here.
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "default",
            "level": "DEBUG",
        }
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console"]
    },
    "loggers": {
        "uvicorn": {"handlers": ["console"], "level": "DEBUG"},
        "fastapi": {"handlers": ["console"], "level": "DEBUG"}
    }
})

# Create logger instance
logger = logging.getLogger(__name__)
# FastAPI application wiring: mount the document-RAG routes and allow
# cross-origin calls from any frontend. NOTE(review): CORS is wide open
# (allow_origins=["*"] with credentials) — confirm this is intentional
# before production use.
app = FastAPI()
app.include_router(document_rag_router)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_current_files():
    """Get list of files in current directory"""
    try:
        entries = os.listdir('.')
    except Exception as e:
        # Surface the failure as a readable string rather than raising.
        return f"Error getting files: {str(e)}"
    return ", ".join(entries)
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Hard-coded demo lookup: Bob is a special case, everyone else gets 41.
    return "42 years old" if "bob" in name.lower() else "41 years old"
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages
        config: Runnable config whose "configurable" mapping must carry
            "collection_id" and "user_id" for the target collection.

    Returns:
        A stringified list of {text, distance, metadata} dicts on success,
        or a human-readable error string the agent can relay to the user.
    """
    # Get collection_id and user_id from config (per-thread routing data).
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"

    try:
        # Create query input
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6
        )
        response = await query_collection(input_data)

        # Response is a Pydantic model; flatten each hit to the fields the
        # agent needs (full location metadata reduced to chunk_index).
        results = []
        for r in response.results:
            results.append({
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index"),
                },
            })
        return str(results)
    except Exception as e:
        # Log through the configured handlers instead of bare print().
        logger.error(f"query_documents failed: {e}", exc_info=True)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
):
    """Use this tool to retrieve relevant data from the collection.

    Unlike query_documents (which stringifies results for the LLM), this
    returns response.results directly so callers can serialize selected
    fields themselves.

    Args:
        query: The search query to find relevant document passages
        config: Runnable config whose "configurable" mapping must carry
            "collection_id" and "user_id".

    Returns:
        The list of search results on success, or an error string.
    """
    # NOTE: the previous "-> SearchResult" annotation was wrong — this
    # function never returns a single SearchResult — so it was removed.
    # Get collection_id and user_id from config
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"

    try:
        # Create query input
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6
        )
        response = await query_collection(input_data)
        return response.results
    except Exception as e:
        # Log through the configured handlers instead of bare print().
        logger.error(f"query_documents_raw failed: {e}", exc_info=True)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
# Conversation state is checkpointed in-process; history is lost on restart.
memory = MemorySaver()
# streaming=True so token chunks can be relayed over SSE as they arrive.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)

# Create a prompt template for formatting: the system message injects the
# collection's file listing, and {messages} carries the running conversation.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}, use query_documents tool to answer user queries from the document. In case a summary is requested, create multiple queries for different plausible sections of the document"),
    ("placeholder", "{messages}"),
])
import requests | |
from requests.exceptions import RequestException, Timeout | |
import logging | |
from typing import Optional | |
# def get_collection_files(collection_id: str, user_id: str) -> str: | |
# """ | |
# Synchronously get list of files in the specified collection using the external API | |
# with proper timeout and error handling. | |
# """ | |
# try: | |
# url = "https://pvanand-documind-api-v2.hf.space/rag/get_collection_files" | |
# params = { | |
# "collection_id": collection_id, | |
# "user_id": user_id | |
# } | |
# headers = { | |
# 'accept': 'application/json' | |
# } | |
# logger.debug(f"Requesting collection files for user {user_id}, collection {collection_id}") | |
# # Set timeout to 5 seconds | |
# response = requests.post(url, params=params, headers=headers, data='', timeout=5) | |
# if response.status_code == 200: | |
# logger.info(f"Successfully retrieved collection files: {response.text[:100]}...") | |
# return response.text | |
# else: | |
# logger.error(f"API error (status {response.status_code}): {response.text}") | |
# return f"Error fetching files (status {response.status_code})" | |
# except Timeout: | |
# logger.error("Timeout while fetching collection files") | |
# return "Error: Request timed out" | |
# except RequestException as e: | |
# logger.error(f"Network error fetching collection files: {str(e)}") | |
# return f"Error: Network issue - {str(e)}" | |
# except Exception as e: | |
# logger.error(f"Error fetching collection files: {str(e)}", exc_info=True) | |
# return f"Error fetching files: {str(e)}" | |
def get_collection_files(collection_id: str, user_id: str) -> str:
    """Get list of files in the specified collection.

    Args:
        collection_id: Collection identifier (suffix of the table name).
        user_id: Owner id (prefix of the table name).

    Returns:
        Comma-separated unique file names, or an "Error getting files: ..."
        string when the table cannot be read.
    """
    try:
        # Tables are namespaced per user: "<user_id>_<collection_id>".
        collection_name = f"{user_id}_{collection_id}"

        # Open the table and convert to pandas
        table = db.open_table(collection_name)
        df = table.to_pandas()
        # Debug-level logging instead of print() so the dump follows the
        # configured handlers and can be silenced in production.
        logging.debug("Collection %s head:\n%s", collection_name, df.head())

        # Get unique file names and join them into a single string
        unique_files = df['file_name'].unique()
        return ", ".join(unique_files)
    except Exception as e:
        logging.error(f"Error getting collection files: {str(e)}")
        return f"Error getting files: {str(e)}"
def format_for_model(state: dict, config: Optional[RunnableConfig] = None) -> list[BaseMessage]:
    """
    Build the prompt input for the model from agent state and thread config.

    Args:
        state: The current state dictionary containing messages
        config: Optional RunnableConfig containing thread configuration
    Returns:
        Formatted messages for the model
    """
    # Collection routing lives in the per-thread config, not in the state.
    configurable = config.get("configurable", {}) if config else {}
    collection_id = configurable.get("collection_id")
    user_id = configurable.get("user_id")

    messages = state.get("messages", [])
    try:
        # Resolve the file listing for the system prompt when we know
        # which collection to look at.
        if collection_id and user_id:
            collection_files = get_collection_files(collection_id, user_id)
        else:
            collection_files = "No files available"
        logger.info(f"Fetching collection for userid {user_id} and collection_id {collection_id} || Results: {collection_files[:100]}...")
        return prompt.invoke({
            "collection_files": collection_files,
            "messages": messages,
        })
    except Exception as e:
        logger.error(f"Error in format_for_model: {str(e)}", exc_info=True)
        # Degrade gracefully so the agent can still respond.
        return prompt.invoke({
            "collection_files": "Error fetching files",
            "messages": messages,
        })
async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a stringified tool-input dict.

    Returns {key: value} when a pair is found, otherwise the raw string
    wrapped in a list.
    """
    first_pair = re.search(r"{\s*'([^']+)':\s*'([^']+)'", tool_input)
    if first_pair is None:
        return [tool_input]
    return {first_pair.group(1): first_pair.group(2)}
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    When *tool_output* mentions query_documents, the embedded list of result
    dicts is parsed with ast.literal_eval and reduced to
    [{"text": ..., "document_id": ...}, ...]. Any other output — or output
    without a parseable "[{ ... }]" span — is returned unchanged.
    """
    if "query_documents" not in tool_output:
        return tool_output

    # Safe literal parsing only — never eval untrusted tool output.
    import ast

    start = tool_output.find("[{")
    end = tool_output.rfind("}]") + 2
    # Bug fix: rfind returns -1 when "}]" is absent, making end == 1, which
    # satisfied the old `end > 0` check and fed a garbage slice to
    # literal_eval. Require a well-ordered span instead.
    if start < 0 or end <= start:
        return tool_output

    try:
        # Convert the string to a Python object using ast.literal_eval
        results = ast.literal_eval(tool_output[start:end])
        # Return only the fields the client needs
        return [
            {"text": r["text"], "document_id": r["metadata"]["document_id"]}
            for r in results
        ]
    except SyntaxError as e:
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        return f"Error processing results: {str(e)}"
# ReAct agent over the RAG tool: format_for_model injects the collection's
# file listing into the system prompt on every turn, and MemorySaver keeps
# per-thread conversation history.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=format_for_model,
)
class ChatInput(BaseModel):
    """Request payload for the chat endpoints."""

    message: str  # the user's chat message
    thread_id: Optional[str] = None  # conversation id; generated when omitted
    collection_id: Optional[str] = None  # RAG collection to query
    user_id: Optional[str] = None  # owner of the collection
# NOTE(review): no route decorator (e.g. @app.post("/chat")) is visible here —
# confirm whether it was lost in extraction or the handler is registered elsewhere.
async def chat(input_data: ChatInput):
    """Stream an agent answer for one chat message as server-sent events.

    Each SSE message is a JSON object of one of three shapes:
      {"type": "token", "content": ...}               model token chunk
      {"type": "tool_start", "tool": ..., "input": ...}   tool invocation began
      {"type": "tool_end", "tool": ..., "output": ...}    tool finished
    """
    # Reuse the caller's thread for checkpointed memory, or start fresh.
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        # astream_events v2 yields structured events for model and tool activity.
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield f"{json.dumps({'type': 'token', 'content': content})}"
            elif kind == "on_tool_start":
                tool_input = str(event['data'].get('input', ''))
                yield f"{json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})}"
            elif kind == "on_tool_end":
                tool_output = str(event['data'].get('output', ''))
                yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})}"

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )
# NOTE(review): like chat(), no route decorator is visible here — confirm how
# this handler is registered.
async def chat2(input_data: ChatInput):
    """Stream an agent answer as SSE, with cleaned tool payloads.

    Differs from chat(): tool inputs are reduced to their first key/value
    pair, and for query_documents the search is re-executed via
    query_documents_raw so the client receives structured results instead of
    the stringified tool output.
    """
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield f"{json.dumps({'type': 'token', 'content': content})}"
            elif kind == "on_tool_start":
                tool_name = event['name']
                tool_input = event['data'].get('input', '')
                # Reduce the raw input repr to its first key/value pair.
                clean_input = await clean_tool_input(str(tool_input))
                yield f"{json.dumps({'type': 'tool_start', 'tool': tool_name, 'inputs': clean_input})}"
            elif kind == "on_tool_end":
                if "query_documents" in event['name']:
                    print(event)
                    # Re-run the query to obtain structured results — note
                    # this executes the search a second time per tool call.
                    raw_output = await query_documents_raw(str(event['data'].get('input', '')), config)
                    try:
                        serializable_output = [
                            {
                                "text": result.text,
                                "distance": result.distance,
                                "metadata": result.metadata
                            }
                            for result in raw_output
                        ]
                        yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': json.dumps(serializable_output)})}"
                    except Exception as e:
                        # Fall back to the raw string when raw_output is not a
                        # list of result objects (e.g. an error string).
                        print(e)
                        yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': str(raw_output)})}"
                else:
                    tool_name = event['name']
                    raw_output = str(event['data'].get('output', ''))
                    clean_output = await clean_tool_response(raw_output)
                    yield f"{json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': clean_output})}"

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )
async def health_check():
    """Liveness probe: report that the service is running."""
    status = {"status": "healthy"}
    return status
if __name__ == "__main__":
    import uvicorn
    # Dev entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)