import asyncio
import json
import logging
import subprocess

from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from jinja2 import Template
from llama_cpp import Llama
# Set up logging
logging.basicConfig(level=logging.INFO)

# Initialize the FastAPI application
app = FastAPI()

# Define the models and their paths
models = {
    "production": {"file": "DeepSeek-R1-Distill-Llama-8B-Q4_K_L.gguf", "alias": "R1Llama8BQ4L"},
    "development": {"file": "/home/ali/Projects/VirtualLabDev/Local/DeepSeek-R1-Distill-Qwen-1.5B-Q2_K.gguf", "alias": "R1Qwen1.5BQ2"},
}

# Load the Llama model
llm = Llama(model_path=models["development"]["file"], n_ctx=2048)
# Define the shell execution tool
def execute_shell(arguments):
    """Execute a shell command."""
    try:
        args = json.loads(arguments)
        command = args.get("command", "")
        if not command:
            return json.dumps({"error": "No command provided."})
        process = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        return json.dumps({"stdout": process.stdout, "stderr": process.stderr})
    except Exception as e:
        return json.dumps({"error": str(e)})
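# Note: shell=True hands the model arbitrary command execution. This sketch
# assumes a disposable local sandbox; for anything else, consider an
# allowlist of commands or a container boundary around the process.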
# Define the tools available to the assistant
tools = {
    "shell": {
        "description": "Execute shell commands.",
        "example_input": '{"command": "ls -l"}',
        "example_output": '{"stdout": "...", "stderr": "..."}',
        "function": execute_shell,
    },
}
# Generate the dynamic system prompt
def generate_system_prompt(tools):
    """Dynamically generate the system prompt based on available tools."""
    tool_descriptions = []
    for tool_name, tool_data in tools.items():
        description = tool_data.get("description", "No description available.")
        example_input = tool_data.get("example_input", "{}")
        example_output = tool_data.get("example_output", "{}")
        tool_descriptions.append(
            f"""- **{tool_name}**:
  - Description: {description}
  - Input: {example_input}
  - Output: {example_output}"""
        )
    return """You are an autonomous computational biology researcher with access to the following tools:\n\n""" + "\n\n".join(tool_descriptions)
# Create the system prompt
system_prompt = generate_system_prompt(tools)
# Tool call parser
def extract_tool_calls(response_text):
    """Parse tool calls from model output."""
    if "<|tool▁calls▁begin|>" not in response_text:
        return []
    tool_calls_part = response_text.split("<|tool▁calls▁begin|>")[1]
    tool_calls_part = tool_calls_part.split("<|tool▁calls▁end|>")[0]
    tool_calls = tool_calls_part.split("<|tool▁call▁begin|>")
    parsed_tool_calls = []
    for tool_call in tool_calls:
        tool_call = tool_call.strip()
        if tool_call:
            try:
                tool_type, tool_name_and_args = tool_call.split("<|tool▁sep|>")
                tool_name, tool_args = tool_name_and_args.split("\n```json\n", 1)
                tool_args = tool_args.split("\n```")[0]
                parsed_tool_calls.append({"type": tool_type, "name": tool_name.strip(), "arguments": tool_args.strip()})
            except ValueError:
                logging.warning("Failed to parse tool call: %s", tool_call)
    return parsed_tool_calls
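# For reference, the raw format the parser above expects looks roughly like
# this (DeepSeek-R1's tool-call special tokens; the exact layout can vary
# between model builds, so treat it as an illustration, not a guarantee):
#
#   <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>shell
#   ```json
#   {"command": "ls -l"}
#   ```<|tool▁call▁end|><|tool▁calls▁end|>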
def process_tool_call(tool_call):
    """Execute the requested tool and return its output."""
    tool_name = tool_call["name"]
    tool_args = tool_call["arguments"]
    if tool_name in tools:
        tool_function = tools[tool_name]["function"]
        return tool_function(tool_args)
    else:
        return json.dumps({"error": f"Tool {tool_name} not found."})
# Chat template for generating prompts
CHAT_TEMPLATE = """
{% for message in messages %}
{% if message.role == "system" -%}
{{ message.content }}
{% elif message.role == "assistant" -%}
<|Assistant|>{{ message.content }}
{% elif message.role == "tool" -%}
<|Tool|>{{ message.content }}
{% endif %}
{% endfor %}
"""
# Response handler for rendering the prompt and streaming the model's output
async def generate_response(conversation):
    """Generate a model response asynchronously, token by token."""
    template = Template(CHAT_TEMPLATE)
    prompt = template.render(messages=conversation, bos_token="")
    logging.info(f"Prompt: {prompt}")
    for token in llm(prompt, stream=True):
        yield token["choices"][0]["text"]  # Stream each completion chunk as it arrives
        await asyncio.sleep(0)  # Yield control so the event loop stays responsive
# WebSocket for streaming autonomous research interactions
@app.websocket("/stream")
async def stream(websocket: WebSocket):
    """WebSocket handler to stream the AI research process."""
    logging.info("WebSocket connection established.")
    await websocket.accept()
    await websocket.send_text("🚀 Autonomous computational biology research initiated!")
    # Initialize the conversation
    conversation = [{"role": "system", "content": system_prompt}]
    while True:
        try:
            # Stream the model's thought process token by token, accumulating
            # the full response: tool-call markers span many tokens, so they
            # can only be parsed once the response is complete.
            response_text = ""
            async for token_text in generate_response(conversation):
                response_text += token_text
                await websocket.send_text(f"🧠 Model Thinking: {token_text}")
            logging.info(f"Response: {response_text}")
            conversation.append({"role": "assistant", "content": response_text})
            # Check for tool calls in the completed response
            tool_calls = extract_tool_calls(response_text)
            logging.info(f"Tool calls: {tool_calls}")
            for tool_call in tool_calls:
                # Process each tool call
                tool_output = process_tool_call(tool_call)
                await websocket.send_text(f"🔧 Tool Execution: {tool_output}")
                # Add the tool's output to the conversation
                conversation.append({"role": "tool", "content": tool_output})
        except Exception as e:
            logging.error(f"Error occurred: {str(e)}")
            break
    await websocket.close()
# Serve the frontend
@app.get("/")
async def get():
    """Serve the frontend application."""
    html_content = """
    <!DOCTYPE html>
    <html>
        <head>
            <title>Autonomous Computational Biology Research</title>
        </head>
        <body>
            <h1>AI Agent for Computational Biology Research</h1>
            <div id="log" style="white-space: pre-line; font-family: monospace;"></div>
            <script>
                const ws = new WebSocket("ws://localhost:8000/stream");
                const log = document.getElementById("log");
                ws.onmessage = (event) => { log.textContent += event.data + "\\n"; };
            </script>
        </body>
    </html>
    """
    return HTMLResponse(html_content)
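# A minimal entry point (an assumption, not part of the original listing):
# port 8000 matches the WebSocket URL hard-coded in the frontend above.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)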