# seekr/main.py
import os
import re
import asyncio
import json
import time
import logging
from typing import Any, Dict
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted
logger = logging.getLogger()
logger.setLevel(logging.INFO)
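# Token budget: keep a safety buffer below the model's context window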
CONTEXT_LENGTH = 128000
BUFFER = 10000
MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER
# Global in-memory store for session state and shared components
SESSION_STORE: Dict[str, Dict[str, Any]] = {}
# Format error message for SSE
def format_error_sse(event_type: str, data: str) -> str:
lines = data.splitlines()
sse_message = f"event: {event_type}\n"
for line in lines:
sse_message += f"data: {line}\n"
sse_message += "\n"
return sse_message
# Initialize the components
def initialize_components():
load_dotenv(override=True)
from src.search.search_engine import SearchEngine
from src.query_processing.query_processor import QueryProcessor
from src.rag.neo4j_graphrag import Neo4jGraphRAG
from src.evaluation.evaluator import Evaluator
from src.reasoning.reasoner import Reasoner
from src.crawl.crawler import CustomCrawler
from src.utils.api_key_manager import APIKeyManager
from src.query_processing.late_chunking.late_chunker import LateChunker
manager = APIKeyManager()
manager._reinit()
SESSION_STORE['search_engine'] = SearchEngine()
SESSION_STORE['query_processor'] = QueryProcessor()
SESSION_STORE['crawler'] = CustomCrawler(max_concurrent_requests=1000)
SESSION_STORE['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2)
SESSION_STORE['evaluator'] = Evaluator()
SESSION_STORE['reasoner'] = Reasoner()
SESSION_STORE['model'] = manager.get_llm()
SESSION_STORE['late_chunker'] = LateChunker()
SESSION_STORE["initialized"] = True
SESSION_STORE["session_id"] = None
async def process_query(user_query: str, sse_queue: asyncio.Queue):
state = SESSION_STORE
try:
category = await state["query_processor"].classify_query(user_query)
cat_lower = category.lower().strip()
if state["session_id"] is None:
state["session_id"] = await state["crawler"].create_session()
user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip()
if cat_lower == "internal knowledge base":
response = ""
async for chunk in state["reasoner"].reason(user_query):
response += chunk
await sse_queue.put(("token", chunk))
await sse_queue.put(("final_message", response))
SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
await sse_queue.put(("action", {
"name": "evaluate",
"payload": {"query": user_query, "response": response}
}))
await sse_queue.put(("complete", "done"))
elif cat_lower == "simple external lookup":
await sse_queue.put(("step", "Searching..."))
optimized_query = await state['search_engine'].generate_optimized_query(user_query)
search_results = await state['search_engine'].search(
optimized_query,
num_results=3,
exclude_filetypes=["pdf"]
)
urls = [r.get('link', 'No URL') for r in search_results]
search_contents = await state['crawler'].fetch_page_contents(
urls,
user_query,
state["session_id"],
max_attempts=1
)
contents = ""
if search_contents:
for k, content in enumerate(search_contents, 1):
if isinstance(content, Exception):
print(f"Error fetching content: {content}")
elif content:
contents += f"Document {k}:\n{content}\n\n"
if len(contents.strip()) > 0:
await sse_queue.put(("step", "Generating Response..."))
token_count = state['model'].get_num_tokens(contents)
if token_count > MAX_TOKENS_ALLOWED:
contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED)
await sse_queue.put(("sources_read", len(search_contents)))
response = ""
async for chunk in state["reasoner"].reason(user_query, contents):
response += chunk
await sse_queue.put(("token", chunk))
await sse_queue.put(("final_message", response))
SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
await sse_queue.put(("action", {
"name": "sources",
"payload": {"search_results": search_results, "search_contents": search_contents}
}))
await sse_queue.put(("action", {
"name": "evaluate",
"payload": {"query": user_query, "contents": [contents], "response": response}
}))
await sse_queue.put(("complete", "done"))
else:
await sse_queue.put(("error", "No results found."))
elif cat_lower == "complex moderate decomposition":
current_search_results = []
current_search_contents = []
await sse_queue.put(("step", "Thinking..."))
start = time.time()
intent = await state['query_processor'].get_query_intent(user_query)
sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent)
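            # Search, filter URLs, and fetch page contents for a single sub-query; returns the concatenated documents.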
async def sub_query_task(sub_query):
try:
await sse_queue.put(("step", "Searching..."))
await sse_queue.put(("task", (sub_query, "RUNNING")))
optimized_query = await state['search_engine'].generate_optimized_query(sub_query)
search_results = await state['search_engine'].search(
optimized_query,
num_results=10,
exclude_filetypes=["pdf"]
)
filtered_urls = await state['search_engine'].filter_urls(
sub_query,
category,
search_results
)
current_search_results.extend(filtered_urls)
urls = [r.get('link', 'No URL') for r in filtered_urls]
search_contents = await state['crawler'].fetch_page_contents(
urls,
sub_query,
state["session_id"],
max_attempts=1
)
current_search_contents.extend(search_contents)
contents = ""
if search_contents:
for k, c in enumerate(search_contents, 1):
if isinstance(c, Exception):
logger.info(f"Error fetching content: {c}")
elif c:
contents += f"Document {k}:\n{c}\n\n"
if len(contents.strip()) > 0:
await sse_queue.put(("task", (sub_query, "DONE")))
else:
await sse_queue.put(("task", (sub_query, "FAILED")))
return contents
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError):
await sse_queue.put(("task", (sub_query, "FAILED")))
return ""
tasks = [sub_query_task(sub_query) for sub_query in sub_queries]
results = await asyncio.gather(*tasks)
end = time.time()
contents = "\n\n".join(r for r in results if r.strip())
unique_results = []
seen = set()
for entry in current_search_results:
link = entry["link"]
if link not in seen:
seen.add(link)
unique_results.append(entry)
current_search_results = unique_results
current_search_contents = list(set(current_search_contents))
if len(contents.strip()) > 0:
await sse_queue.put(("step", "Generating Response..."))
token_count = state['model'].get_num_tokens(contents)
if token_count > MAX_TOKENS_ALLOWED:
contents = await state['late_chunker'].chunker(
text=contents,
query=user_query,
max_tokens=MAX_TOKENS_ALLOWED
)
logger.info(f"Number of tokens in the answer: {token_count}")
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")
await sse_queue.put(("sources_read", len(current_search_contents)))
response = ""
is_first_chunk = True
async for chunk in state['reasoner'].reason(user_query, contents):
if is_first_chunk:
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
is_first_chunk = False
response += chunk
await sse_queue.put(("token", chunk))
await sse_queue.put(("final_message", response))
SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
await sse_queue.put(("action", {
"name": "sources",
"payload": {
"search_results": current_search_results,
"search_contents": current_search_contents
}
}))
await sse_queue.put(("action", {
"name": "evaluate",
"payload": {"query": user_query, "contents": [contents], "response": response}
}))
await sse_queue.put(("complete", "done"))
else:
await sse_queue.put(("error", "No results found."))
elif cat_lower == "complex advanced decomposition":
current_search_results = []
current_search_contents = []
await sse_queue.put(("step", "Thinking..."))
start = time.time()
main_query_intent = await state['query_processor'].get_query_intent(user_query)
sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent)
await sse_queue.put(("step", "Searching..."))
async def sub_query_task(sub_query):
try:
async def sub_sub_query_task(sub_sub_query):
optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query)
search_results = await state['search_engine'].search(
optimized_query,
num_results=10,
exclude_filetypes=["pdf"]
)
filtered_urls = await state['search_engine'].filter_urls(
sub_sub_query,
category,
search_results
)
current_search_results.extend(filtered_urls)
urls = [r.get('link', 'No URL') for r in filtered_urls]
search_contents = await state['crawler'].fetch_page_contents(
urls,
sub_sub_query,
state["session_id"],
max_attempts=1,
timeout=20
)
current_search_contents.extend(search_contents)
contents = ""
if search_contents:
for k, c in enumerate(search_contents, 1):
if isinstance(c, Exception):
logger.info(f"Error fetching content: {c}")
elif c:
contents += f"Document {k}:\n{c}\n\n"
return contents
await sse_queue.put(("task", (sub_query, "RUNNING")))
sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query)
tasks = [sub_sub_query_task(sub_sub_query) for sub_sub_query in sub_sub_queries]
results = await asyncio.gather(*tasks)
if any(result.strip() for result in results):
await sse_queue.put(("task", (sub_query, "DONE")))
else:
await sse_queue.put(("task", (sub_query, "FAILED")))
return results
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError):
await sse_queue.put(("task", (sub_query, "FAILED")))
return []
tasks = [sub_query_task(sub_query) for sub_query in sub_queries]
results = await asyncio.gather(*tasks)
end = time.time()
previous_contents = []
for result in results:
if result:
for content in result:
if isinstance(content, str) and len(content.strip()) > 0:
previous_contents.append(content)
contents = "\n\n".join(previous_contents)
unique_results = []
seen = set()
for entry in current_search_results:
link = entry["link"]
if link not in seen:
seen.add(link)
unique_results.append(entry)
current_search_results = unique_results
current_search_contents = list(set(current_search_contents))
if len(contents.strip()) > 0:
await sse_queue.put(("step", "Generating Response..."))
token_count = state['model'].get_num_tokens(contents)
if token_count > MAX_TOKENS_ALLOWED:
contents = await state['late_chunker'].chunker(
text=contents,
query=user_query,
max_tokens=MAX_TOKENS_ALLOWED
)
logger.info(f"Number of tokens in the answer: {token_count}")
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}")
await sse_queue.put(("sources_read", len(current_search_contents)))
response = ""
is_first_chunk = True
async for chunk in state['reasoner'].reason(user_query, contents):
if is_first_chunk:
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
is_first_chunk = False
response += chunk
await sse_queue.put(("token", chunk))
await sse_queue.put(("final_message", response))
SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
await sse_queue.put(("action", {
"name": "sources",
"payload": {
"search_results": current_search_results,
"search_contents": current_search_contents
}
}))
await sse_queue.put(("action", {
"name": "evaluate",
"payload": {"query": user_query, "contents": [contents], "response": response}
}))
await sse_queue.put(("complete", "done"))
else:
await sse_queue.put(("error", "No results found."))
elif cat_lower == "extensive research dynamic structuring":
current_search_results = []
current_search_contents = []
match = re.search(
r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$",
user_query,
flags=re.DOTALL | re.MULTILINE
)
if match:
user_query = match.group(1)
await sse_queue.put(("step", "Thinking..."))
await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent
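            # Progress callback invoked by the GraphRAG pipeline; forwards graph and search events to the SSE queue.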
async def on_event_callback(event_type, data):
if event_type == "graph_operation":
if data["operation_type"] == "creating_new_graph":
await sse_queue.put(("step", "Creating New Graph..."))
elif data["operation_type"] == "modifying_existing_graph":
await sse_queue.put(("step", "Modifying Existing Graph..."))
elif data["operation_type"] == "loading_existing_graph":
await sse_queue.put(("step", "Loading Existing Graph..."))
elif event_type == "sub_query_created":
sub_query = data["sub_query"]
await sse_queue.put(("task", (sub_query, "RUNNING")))
elif event_type == "search_process_started":
await sse_queue.put(("step", "Searching..."))
elif event_type == "sub_query_processed":
sub_query = data["sub_query"]
await sse_queue.put(("task", (sub_query, "DONE")))
elif event_type == "sub_query_failed":
sub_query = data["sub_query"]
await sse_queue.put(("task", (sub_query, "FAILED")))
elif event_type == "search_results_filtered":
current_search_results.extend(data["filtered_urls"])
filtered_urls = data["filtered_urls"]
current_search_results.extend(filtered_urls)
elif event_type == "search_contents_fetched":
current_search_contents.extend(data["contents"])
contents = data["contents"]
current_search_contents.extend(contents)
state['graph_rag'].set_on_event_callback(on_event_callback)
start = time.time()
state['graph_rag'].initialize_schema()
await state['graph_rag'].process_graph(
user_query,
similarity_threshold=0.8,
relevance_threshold=0.8,
max_tokens_allowed=MAX_TOKENS_ALLOWED
)
end = time.time()
unique_results = []
seen = set()
for entry in current_search_results:
link = entry["link"]
if link not in seen:
seen.add(link)
unique_results.append(entry)
current_search_results = unique_results
current_search_contents = list(set(current_search_contents))
await sse_queue.put(("step", "Generating Response..."))
answer = state['graph_rag'].query_graph(user_query)
if answer:
token_count = state['model'].get_num_tokens(answer)
if token_count > MAX_TOKENS_ALLOWED:
answer = await state['late_chunker'].chunker(
text=answer,
query=user_query,
max_tokens=MAX_TOKENS_ALLOWED
)
logger.info(f"Number of tokens in the answer: {token_count}")
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(answer)}")
await sse_queue.put(("sources_read", len(current_search_contents)))
response = ""
is_first_chunk = True
async for chunk in state['reasoner'].reason(user_query, answer):
if is_first_chunk:
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds"))
is_first_chunk = False
response += chunk
await sse_queue.put(("token", chunk))
await sse_queue.put(("final_message", response))
SESSION_STORE["chat_history"].append({"query": user_query, "response": response})
await sse_queue.put(("action", {
"name": "sources",
"payload": {"search_results": current_search_results, "search_contents": current_search_contents},
}))
await sse_queue.put(("action", {
"name": "graph",
"payload": {"query": user_query},
}))
await sse_queue.put(("action", {
"name": "evaluate",
"payload": {"query": user_query, "contents": [answer], "response": response},
}))
await sse_queue.put(("complete", "done"))
else:
await sse_queue.put(("error", "No results found."))
else:
await sse_queue.put(("final_message", "I'm not sure how to handle your query."))
except Exception as e:
await sse_queue.put(("error", str(e)))
# Create a FastAPI app
app = FastAPI()
# Define allowed origins
origins = [
    "http://localhost:3000",
    "http://localhost:7860",
    "http://localhost:8000",
    "http://localhost"
]
# Add the CORS middleware to your FastAPI app
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # Allows only these origins
allow_credentials=True,
allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.)
allow_headers=["*"], # Allows all headers
)
# Serve the React production build's static assets; the catch-all route below returns index.html.
app.mount("/static", StaticFiles(directory="static/static", html=True), name="static")
# Catch-all route for frontend paths.
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str, request: Request):
if full_path.startswith("action") or full_path in ["settings", "message-sse", "stop"]:
raise HTTPException(status_code=404, detail="Not Found")
index_path = os.path.join("static", "index.html")
if not os.path.exists(index_path):
raise HTTPException(status_code=500, detail="Frontend build not found")
return FileResponse(index_path)
# Define the routes for the FastAPI app
# Define the route for sources action to display search results
@app.post("/action/sources")
def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
try:
search_contents = payload.get("search_contents", [])
search_results = payload.get("search_results", [])
sources = []
word_limit = 15 # Maximum number of words for the description
for result, contents in zip(search_results, search_contents):
if contents:
title = result.get('title', 'No Title')
link = result.get('link', 'No URL')
snippet = result.get('snippet', 'No snippet')
cleaned = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet)
words = cleaned.split()
if len(words) > word_limit:
description = " ".join(words[:word_limit]) + "..."
else:
description = " ".join(words)
source_obj = {
"title": title,
"link": link,
"description": description
}
sources.append(source_obj)
return {"result": sources}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for graph action to display the graph
@app.post("/action/graph")
def action_graph(payload: Dict[str, Any]) -> Dict[str, Any]:
state = SESSION_STORE
try:
q = payload.get("query", "")
html_str = state['graph_rag'].display_graph(q)
return {"result": html_str}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for evaluate action to display evaluation results
@app.post("/action/evaluate")
async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]:
try:
query = payload.get("query", "")
contents = payload.get("contents", [])
response = payload.get("response", "")
metrics = payload.get("metrics", [])
state = SESSION_STORE
evaluator = state["evaluator"]
result = await evaluator.evaluate_response(query, response, contents, include_metrics=metrics)
return {"result": result}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
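# Persist provider, model, search, proxy, and Neo4j settings to the environment, then reinitialize components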
@app.post("/settings")
async def update_settings(data: Dict[str, Any]):
from src.helpers.helper import (
prepare_provider_key_updates,
prepare_proxy_list_updates,
update_env_vars
)
provider = data.get("Model_Provider", "").strip()
model_name = data.get("Model_Name", "").strip()
multiple_api_keys = data.get("Model_API_Keys", "").strip()
brave_api_key = data.get("Brave_Search_API_Key", "").strip()
proxy_list = data.get("Proxy_List", "").strip()
neo4j_url = data.get("Neo4j_URL", "").strip()
neo4j_username = data.get("Neo4j_Username", "").strip()
neo4j_password = data.get("Neo4j_Password", "").strip()
model_temperature = str(data.get("Model_Temperature", 0.0))
model_top_p = str(data.get("Model_Top_P", 1.0))
prov_lower = provider.lower()
key_updates = prepare_provider_key_updates(prov_lower, multiple_api_keys)
env_updates = {}
env_updates.update(key_updates)
px = prepare_proxy_list_updates(proxy_list)
if px:
env_updates.update(px)
env_updates["BRAVE_API_KEY"] = brave_api_key
env_updates["NEO4J_URI"] = neo4j_url
env_updates["NEO4J_USER"] = neo4j_username
env_updates["NEO4J_PASSWORD"] = neo4j_password
env_updates["MODEL_PROVIDER"] = prov_lower
env_updates["MODEL_NAME"] = model_name
env_updates["MODEL_TEMPERATURE"] = model_temperature
env_updates["MODEL_TOP_P"] = model_top_p
update_env_vars(env_updates)
load_dotenv(override=True)
initialize_components()
return {"success": True}
@app.on_event("startup")
def init_chat():
if not SESSION_STORE:
print("Initializing chat...")
SESSION_STORE["settings_saved"] = False
SESSION_STORE["session_id"] = None
SESSION_STORE["chat_history"] = []
print("Chat initialized!")
return {"sucess": True}
else:
print("Chat already initialized!")
return {"success": False}
@app.get("/message-sse")
async def sse_message(request: Request, user_message: str):
state = SESSION_STORE
sse_queue = asyncio.Queue()
async def event_generator():
# Build the prompt
context = state["chat_history"][-5:]
if context:
prompt = \
f"""This is the previous context of the conversation:
{context}
Current Query:
{user_message}"""
else:
prompt = user_message
task = asyncio.create_task(process_query(prompt, sse_queue))
state["process_task"] = task
while True:
if await request.is_disconnected():
task.cancel()
break
try:
event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5)
if event_type == "token":
yield f"event: token\ndata: {data}\n\n"
elif event_type == "final_message":
yield f"event: final_message\ndata: {data}\n\n"
elif event_type == "error":
yield format_error_sse("error", data)
elif event_type == "step":
yield f"event: step\ndata: {data}\n\n"
elif event_type == "task":
subq, status = data
j = {"task": subq, "status": status}
yield f"event: task\ndata: {json.dumps(j)}\n\n"
elif event_type == "sources_read":
yield f"event: sources_read\ndata: {data}\n\n"
elif event_type == "action":
yield f"event: action\ndata: {json.dumps(data)}\n\n"
elif event_type == "complete":
yield f"event: complete\ndata: {data}\n\n"
break
else:
yield f"event: message\ndata: {data}\n\n"
except asyncio.TimeoutError:
if task.done():
break
continue
except asyncio.CancelledError:
break
if not task.done():
task.cancel()
if "process_task" in state:
del state["process_task"]
return StreamingResponse(event_generator(), media_type="text/event-stream")
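# Cancel the in-flight query-processing task, if any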
@app.post("/stop")
def stop():
state = SESSION_STORE
if "process_task" in state:
state["process_task"].cancel()
del state["process_task"]
return {"message": "Stopped task manually"}