import os
import re
import asyncio
import json
import time
import logging
from typing import Any, Dict

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted
logger = logging.getLogger()
logger.setLevel(logging.INFO)

CONTEXT_LENGTH = 128000
BUFFER = 10000
MAX_TOKENS_ALLOWED = CONTEXT_LENGTH - BUFFER
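# MAX_TOKENS_ALLOWED keeps BUFFER tokens of headroom below the model's context
# window for the prompt scaffolding and the generated answer.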
# Per-session state
SESSION_STORE: Dict[str, Dict[str, Any]] = {}

# Format error message for SSE
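# (the SSE wire format requires each line of a multi-line payload to be sent
# as its own "data:" field, with a blank line terminating the event)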
def format_error_sse(event_type: str, data: str) -> str:
    lines = data.splitlines()
    sse_message = f"event: {event_type}\n"
    for line in lines:
        sse_message += f"data: {line}\n"
    sse_message += "\n"
    return sse_message
# Initialize the components
def initialize_components():
    load_dotenv(override=True)
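    # Deferred imports: pulling these in after load_dotenv lets a settings
    # update re-initialize every component against fresh environment variables.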
    from src.search.search_engine import SearchEngine
    from src.query_processing.query_processor import QueryProcessor
    from src.rag.neo4j_graphrag import Neo4jGraphRAG
    from src.evaluation.evaluator import Evaluator
    from src.reasoning.reasoner import Reasoner
    from src.crawl.crawler import CustomCrawler
    from src.utils.api_key_manager import APIKeyManager
    from src.query_processing.late_chunking.late_chunker import LateChunker

    manager = APIKeyManager()
    manager._reinit()

    SESSION_STORE['search_engine'] = SearchEngine()
    SESSION_STORE['query_processor'] = QueryProcessor()
    SESSION_STORE['crawler'] = CustomCrawler(max_concurrent_requests=1000)
    SESSION_STORE['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2)
    SESSION_STORE['evaluator'] = Evaluator()
    SESSION_STORE['reasoner'] = Reasoner()
    SESSION_STORE['model'] = manager.get_llm()
    SESSION_STORE['late_chunker'] = LateChunker()
    SESSION_STORE["initialized"] = True
    SESSION_STORE["session_id"] = None
async def process_query(user_query: str, sse_queue: asyncio.Queue):
    state = SESSION_STORE
    try:
        category = await state["query_processor"].classify_query(user_query)
        cat_lower = category.lower().strip()

        if state["session_id"] is None:
            state["session_id"] = await state["crawler"].create_session()
        user_query = re.sub(r'category:.*', '', user_query, flags=re.IGNORECASE).strip()

        if cat_lower == "internal knowledge base":
            response = ""
            async for chunk in state["reasoner"].reason(user_query):
                response += chunk
                await sse_queue.put(("token", chunk))

            await sse_queue.put(("final_message", response))
            SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

            await sse_queue.put(("action", {
                "name": "evaluate",
                "payload": {"query": user_query, "response": response}
            }))
            await sse_queue.put(("complete", "done"))
elif cat_lower == "simple external lookup": | |
await sse_queue.put(("step", "Searching...")) | |
optimized_query = await state['search_engine'].generate_optimized_query(user_query) | |
search_results = await state['search_engine'].search( | |
optimized_query, | |
num_results=3, | |
exclude_filetypes=["pdf"] | |
) | |
urls = [r.get('link', 'No URL') for r in search_results] | |
search_contents = await state['crawler'].fetch_page_contents( | |
urls, | |
user_query, | |
state["session_id"], | |
max_attempts=1 | |
) | |
contents = "" | |
if search_contents: | |
for k, content in enumerate(search_contents, 1): | |
if isinstance(content, Exception): | |
print(f"Error fetching content: {content}") | |
                    elif content:
                        contents += f"Document {k}:\n{content}\n\n"

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))
                token_count = state['model'].get_num_tokens(contents)
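                # Query-aware late chunking compresses the context when it
                # would overflow the model's usable window.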
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(contents, user_query, MAX_TOKENS_ALLOWED)

                await sse_queue.put(("sources_read", len(search_contents)))

                response = ""
                async for chunk in state["reasoner"].reason(user_query, contents):
                    response += chunk
                    await sse_queue.put(("token", chunk))

                await sse_queue.put(("final_message", response))
                SESSION_STORE["chat_history"].append({"query": user_query, "response": response})

                await sse_queue.put(("action", {
                    "name": "sources",
                    "payload": {"search_results": search_results, "search_contents": search_contents}
                }))
                await sse_queue.put(("action", {
                    "name": "evaluate",
                    "payload": {"query": user_query, "contents": [contents], "response": response}
                }))
                await sse_queue.put(("complete", "done"))
            else:
                await sse_queue.put(("error", "No results found."))
elif cat_lower == "complex moderate decomposition": | |
current_search_results = [] | |
current_search_contents = [] | |
await sse_queue.put(("step", "Thinking...")) | |
start = time.time() | |
intent = await state['query_processor'].get_query_intent(user_query) | |
sub_queries, _ = await state['query_processor'].decompose_query(user_query, intent) | |
async def sub_query_task(sub_query): | |
try: | |
await sse_queue.put(("step", "Searching...")) | |
await sse_queue.put(("task", (sub_query, "RUNNING"))) | |
optimized_query = await state['search_engine'].generate_optimized_query(sub_query) | |
search_results = await state['search_engine'].search( | |
optimized_query, | |
num_results=10, | |
exclude_filetypes=["pdf"] | |
) | |
filtered_urls = await state['search_engine'].filter_urls( | |
sub_query, | |
category, | |
search_results | |
) | |
current_search_results.extend(filtered_urls) | |
urls = [r.get('link', 'No URL') for r in filtered_urls] | |
search_contents = await state['crawler'].fetch_page_contents( | |
urls, | |
sub_query, | |
state["session_id"], | |
max_attempts=1 | |
) | |
current_search_contents.extend(search_contents) | |
contents = "" | |
if search_contents: | |
for k, c in enumerate(search_contents, 1): | |
if isinstance(c, Exception): | |
logger.info(f"Error fetching content: {c}") | |
elif c: | |
contents += f"Document {k}:\n{c}\n\n" | |
if len(contents.strip()) > 0: | |
await sse_queue.put(("task", (sub_query, "DONE"))) | |
else: | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return contents | |
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError): | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return "" | |
            tasks = [sub_query_task(sub_query) for sub_query in sub_queries]
            results = await asyncio.gather(*tasks)
            end = time.time()

            contents = "\n\n".join(r for r in results if r.strip())
            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))
                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(
                        text=contents,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )
logger.info(f"Number of tokens in the answer: {token_count}") | |
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}") | |
await sse_queue.put(("sources_read", len(current_search_contents))) | |
response = "" | |
is_first_chunk = True | |
async for chunk in state['reasoner'].reason(user_query, contents): | |
if is_first_chunk: | |
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) | |
is_first_chunk = False | |
response += chunk | |
await sse_queue.put(("token", chunk)) | |
await sse_queue.put(("final_message", response)) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": { | |
"search_results": current_search_results, | |
"search_contents": current_search_contents | |
} | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [contents], "response": response} | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
elif cat_lower == "complex advanced decomposition": | |
current_search_results = [] | |
current_search_contents = [] | |
await sse_queue.put(("step", "Thinking...")) | |
start = time.time() | |
main_query_intent = await state['query_processor'].get_query_intent(user_query) | |
sub_queries, _ = await state['query_processor'].decompose_query(user_query, main_query_intent) | |
await sse_queue.put(("step", "Searching...")) | |
async def sub_query_task(sub_query): | |
try: | |
async def sub_sub_query_task(sub_sub_query): | |
optimized_query = await state['search_engine'].generate_optimized_query(sub_sub_query) | |
search_results = await state['search_engine'].search( | |
optimized_query, | |
num_results=10, | |
exclude_filetypes=["pdf"] | |
) | |
filtered_urls = await state['search_engine'].filter_urls( | |
sub_sub_query, | |
category, | |
search_results | |
) | |
current_search_results.extend(filtered_urls) | |
urls = [r.get('link', 'No URL') for r in filtered_urls] | |
search_contents = await state['crawler'].fetch_page_contents( | |
urls, | |
sub_sub_query, | |
state["session_id"], | |
max_attempts=1, | |
timeout=20 | |
) | |
current_search_contents.extend(search_contents) | |
contents = "" | |
if search_contents: | |
for k, c in enumerate(search_contents, 1): | |
if isinstance(c, Exception): | |
logger.info(f"Error fetching content: {c}") | |
elif c: | |
contents += f"Document {k}:\n{c}\n\n" | |
return contents | |
await sse_queue.put(("task", (sub_query, "RUNNING"))) | |
sub_sub_queries, _ = await state['query_processor'].decompose_query(sub_query) | |
tasks = [sub_sub_query_task(sub_sub_query) for sub_sub_query in sub_sub_queries] | |
results = await asyncio.gather(*tasks) | |
if any(result.strip() for result in results): | |
await sse_queue.put(("task", (sub_query, "DONE"))) | |
else: | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return results | |
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError): | |
await sse_queue.put(("task", (sub_query, "FAILED"))) | |
return [] | |
tasks = [sub_query_task(sub_query) for sub_query in sub_queries] | |
results = await asyncio.gather(*tasks) | |
end = time.time() | |
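            # Flatten the per-sub-query result lists, keeping only non-empty
            # text blocks.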
            previous_contents = []
            for result in results:
                if result:
                    for content in result:
                        if isinstance(content, str) and len(content.strip()) > 0:
                            previous_contents.append(content)
            contents = "\n\n".join(previous_contents)

            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            if len(contents.strip()) > 0:
                await sse_queue.put(("step", "Generating Response..."))
                token_count = state['model'].get_num_tokens(contents)
                if token_count > MAX_TOKENS_ALLOWED:
                    contents = await state['late_chunker'].chunker(
                        text=contents,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )
logger.info(f"Number of tokens in the answer: {token_count}") | |
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(contents)}") | |
await sse_queue.put(("sources_read", len(current_search_contents))) | |
response = "" | |
is_first_chunk = True | |
async for chunk in state['reasoner'].reason(user_query, contents): | |
if is_first_chunk: | |
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) | |
is_first_chunk = False | |
response += chunk | |
await sse_queue.put(("token", chunk)) | |
await sse_queue.put(("final_message", response)) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": { | |
"search_results": current_search_results, | |
"search_contents": current_search_contents | |
} | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [contents], "response": response} | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
elif cat_lower == "extensive research dynamic structuring": | |
current_search_results = [] | |
current_search_contents = [] | |
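            # If sse_message wrapped the query with chat history, recover the
            # bare current query before building the graph.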
            match = re.search(
                r"^This is the previous context of the conversation:\s*.*?\s*Current Query:\s*(.*)$",
                user_query,
                flags=re.DOTALL | re.MULTILINE
            )
            if match:
                user_query = match.group(1)

            await sse_queue.put(("step", "Thinking..."))
            await asyncio.sleep(0.01)  # Sleep for a short time to allow the message to be sent

            async def on_event_callback(event_type, data):
                if event_type == "graph_operation":
                    if data["operation_type"] == "creating_new_graph":
                        await sse_queue.put(("step", "Creating New Graph..."))
                    elif data["operation_type"] == "modifying_existing_graph":
                        await sse_queue.put(("step", "Modifying Existing Graph..."))
                    elif data["operation_type"] == "loading_existing_graph":
                        await sse_queue.put(("step", "Loading Existing Graph..."))
                elif event_type == "sub_query_created":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "RUNNING")))
                elif event_type == "search_process_started":
                    await sse_queue.put(("step", "Searching..."))
                elif event_type == "sub_query_processed":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "DONE")))
                elif event_type == "sub_query_failed":
                    sub_query = data["sub_query"]
                    await sse_queue.put(("task", (sub_query, "FAILED")))
elif event_type == "search_results_filtered": | |
current_search_results.extend(data["filtered_urls"]) | |
filtered_urls = data["filtered_urls"] | |
current_search_results.extend(filtered_urls) | |
elif event_type == "search_contents_fetched": | |
current_search_contents.extend(data["contents"]) | |
contents = data["contents"] | |
current_search_contents.extend(contents) | |
            state['graph_rag'].set_on_event_callback(on_event_callback)

            start = time.time()
            state['graph_rag'].initialize_schema()
            await state['graph_rag'].process_graph(
                user_query,
                similarity_threshold=0.8,
                relevance_threshold=0.8,
                max_tokens_allowed=MAX_TOKENS_ALLOWED
            )
            end = time.time()
            unique_results = []
            seen = set()
            for entry in current_search_results:
                link = entry["link"]
                if link not in seen:
                    seen.add(link)
                    unique_results.append(entry)
            current_search_results = unique_results
            current_search_contents = list(set(current_search_contents))

            await sse_queue.put(("step", "Generating Response..."))
            answer = state['graph_rag'].query_graph(user_query)
            if answer:
                token_count = state['model'].get_num_tokens(answer)
                if token_count > MAX_TOKENS_ALLOWED:
                    answer = await state['late_chunker'].chunker(
                        text=answer,
                        query=user_query,
                        max_tokens=MAX_TOKENS_ALLOWED
                    )
logger.info(f"Number of tokens in the answer: {token_count}") | |
logger.info(f"Number of tokens in the content: {state['model'].get_num_tokens(answer)}") | |
await sse_queue.put(("sources_read", len(current_search_contents))) | |
response = "" | |
is_first_chunk = True | |
async for chunk in state['reasoner'].reason(user_query, answer): | |
if is_first_chunk: | |
await sse_queue.put(("step", f"Thought and searched for {int(end - start)} seconds")) | |
is_first_chunk = False | |
response += chunk | |
await sse_queue.put(("token", chunk)) | |
await sse_queue.put(("final_message", response)) | |
SESSION_STORE["chat_history"].append({"query": user_query, "response": response}) | |
await sse_queue.put(("action", { | |
"name": "sources", | |
"payload": {"search_results": current_search_results, "search_contents": current_search_contents}, | |
})) | |
await sse_queue.put(("action", { | |
"name": "graph", | |
"payload": {"query": user_query}, | |
})) | |
await sse_queue.put(("action", { | |
"name": "evaluate", | |
"payload": {"query": user_query, "contents": [answer], "response": response}, | |
})) | |
await sse_queue.put(("complete", "done")) | |
else: | |
await sse_queue.put(("error", "No results found.")) | |
else: | |
await sse_queue.put(("final_message", "I'm not sure how to handle your query.")) | |
except Exception as e: | |
await sse_queue.put(("error", str(e))) | |
# Create a FastAPI app
app = FastAPI()

# Define allowed origins
origins = [
    "http://localhost:3000",
    "http://localhost:7860",
    "http://localhost:8000",
    "http://localhost"
]
# Add the CORS middleware to the FastAPI app
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Allows only these origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],  # Allows all headers
)

# Serve the React app's production build: static assets are mounted here,
# while index.html itself is returned by the catch-all route below.
app.mount("/static", StaticFiles(directory="static/static", html=True), name="static")
# Catch-all route for frontend paths.
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str, request: Request):
    if full_path.startswith("action") or full_path in ["settings", "message-sse", "stop"]:
        raise HTTPException(status_code=404, detail="Not Found")

    index_path = os.path.join("static", "index.html")
    if not os.path.exists(index_path):
        raise HTTPException(status_code=500, detail="Frontend build not found")
    return FileResponse(index_path)
# Define the routes for the FastAPI app
# (the HTTP methods and paths below are inferred from the API paths excluded
# from the GET catch-all above)

# Define the route for sources action to display search results
@app.post("/action/sources")
def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        search_contents = payload.get("search_contents", [])
        search_results = payload.get("search_results", [])
        sources = []
        word_limit = 15  # Maximum number of words for the description

        for result, contents in zip(search_results, search_contents):
            if contents:
                title = result.get('title', 'No Title')
                link = result.get('link', 'No URL')
                snippet = result.get('snippet', 'No snippet')

                # Strip HTML tags and bracketed markup, then truncate to the word limit.
                cleaned = re.sub(r'<[^>]+>|\[\/?.*?\]', '', snippet)
                words = cleaned.split()
                if len(words) > word_limit:
                    description = " ".join(words[:word_limit]) + "..."
                else:
                    description = " ".join(words)

                source_obj = {
                    "title": title,
                    "link": link,
                    "description": description
                }
                sources.append(source_obj)
        return {"result": sources}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for graph action to display the graph
@app.post("/action/graph")
def action_graph(payload: Dict[str, Any]) -> Dict[str, Any]:
    state = SESSION_STORE
    try:
        q = payload.get("query", "")
        html_str = state['graph_rag'].display_graph(q)
        return {"result": html_str}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Define the route for evaluate action to display evaluation results
@app.post("/action/evaluate")
async def action_evaluate(payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        query = payload.get("query", "")
        contents = payload.get("contents", [])
        response = payload.get("response", "")
        metrics = payload.get("metrics", [])

        state = SESSION_STORE
        evaluator = state["evaluator"]
        result = await evaluator.evaluate_response(query, response, contents, include_metrics=metrics)
        return {"result": result}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
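# Example settings payload (field names as consumed below; values illustrative):
#   {"Model_Provider": "openai", "Model_Name": "gpt-4o-mini",
#    "Model_API_Keys": "key1,key2", "Brave_Search_API_Key": "...",
#    "Proxy_List": "...", "Neo4j_URL": "bolt://localhost:7687",
#    "Neo4j_Username": "neo4j", "Neo4j_Password": "...",
#    "Model_Temperature": 0.0, "Model_Top_P": 1.0}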
@app.post("/settings")
async def update_settings(data: Dict[str, Any]):
    from src.helpers.helper import (
        prepare_provider_key_updates,
        prepare_proxy_list_updates,
        update_env_vars
    )

    provider = data.get("Model_Provider", "").strip()
    model_name = data.get("Model_Name", "").strip()
    multiple_api_keys = data.get("Model_API_Keys", "").strip()
    brave_api_key = data.get("Brave_Search_API_Key", "").strip()
    proxy_list = data.get("Proxy_List", "").strip()
    neo4j_url = data.get("Neo4j_URL", "").strip()
    neo4j_username = data.get("Neo4j_Username", "").strip()
    neo4j_password = data.get("Neo4j_Password", "").strip()
    model_temperature = str(data.get("Model_Temperature", 0.0))
    model_top_p = str(data.get("Model_Top_P", 1.0))

    prov_lower = provider.lower()
    key_updates = prepare_provider_key_updates(prov_lower, multiple_api_keys)
    env_updates = {}
    env_updates.update(key_updates)

    px = prepare_proxy_list_updates(proxy_list)
    if px:
        env_updates.update(px)

    env_updates["BRAVE_API_KEY"] = brave_api_key
    env_updates["NEO4J_URI"] = neo4j_url
    env_updates["NEO4J_USER"] = neo4j_username
    env_updates["NEO4J_PASSWORD"] = neo4j_password
    env_updates["MODEL_PROVIDER"] = prov_lower
    env_updates["MODEL_NAME"] = model_name
    env_updates["MODEL_TEMPERATURE"] = model_temperature
    env_updates["MODEL_TOP_P"] = model_top_p

    update_env_vars(env_updates)
    load_dotenv(override=True)
    initialize_components()
    return {"success": True}
@app.post("/init_chat")  # route path assumed; the frontend's actual path is not referenced in this file
def init_chat():
    if not SESSION_STORE:
        print("Initializing chat...")
        SESSION_STORE["settings_saved"] = False
        SESSION_STORE["session_id"] = None
        SESSION_STORE["chat_history"] = []
        print("Chat initialized!")
        return {"success": True}
    else:
        print("Chat already initialized!")
        return {"success": False}
@app.post("/message-sse")
async def sse_message(request: Request, user_message: str):
    state = SESSION_STORE
    sse_queue = asyncio.Queue()

    async def event_generator():
        # Build the prompt
        context = state["chat_history"][-5:]
        if context:
            prompt = \
f"""This is the previous context of the conversation:
{context}
Current Query:
{user_message}"""
        else:
            prompt = user_message

        task = asyncio.create_task(process_query(prompt, sse_queue))
        state["process_task"] = task
        while True:
            if await request.is_disconnected():
                task.cancel()
                break

            try:
                event_type, data = await asyncio.wait_for(sse_queue.get(), timeout=5)

                if event_type == "token":
                    yield f"event: token\ndata: {data}\n\n"
                elif event_type == "final_message":
                    yield f"event: final_message\ndata: {data}\n\n"
                elif event_type == "error":
                    yield format_error_sse("error", data)
                elif event_type == "step":
                    yield f"event: step\ndata: {data}\n\n"
                elif event_type == "task":
                    subq, status = data
                    j = {"task": subq, "status": status}
                    yield f"event: task\ndata: {json.dumps(j)}\n\n"
                elif event_type == "sources_read":
                    yield f"event: sources_read\ndata: {data}\n\n"
                elif event_type == "action":
                    yield f"event: action\ndata: {json.dumps(data)}\n\n"
                elif event_type == "complete":
                    yield f"event: complete\ndata: {data}\n\n"
                    break
                else:
                    yield f"event: message\ndata: {data}\n\n"
            except asyncio.TimeoutError:
                if task.done():
                    break
                continue
            except asyncio.CancelledError:
                break

        if not task.done():
            task.cancel()
        if "process_task" in state:
            del state["process_task"]

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.post("/stop")
def stop():
    state = SESSION_STORE
    if "process_task" in state:
        state["process_task"].cancel()
        del state["process_task"]
    return {"message": "Stopped task manually"}