Spaces:
Sleeping
Sleeping
""" | |
Entrypoint for Gradio, see https://gradio.app/ | |
""" | |
import platform | |
import asyncio | |
import base64 | |
import os | |
from datetime import datetime | |
from enum import StrEnum | |
from functools import partial | |
from pathlib import Path | |
from typing import cast, Dict | |
import gradio as gr | |
from anthropic import APIResponse | |
from anthropic.types import TextBlock | |
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock | |
from anthropic.types.tool_use_block import ToolUseBlock | |
from computer_use_demo.loop import ( | |
PROVIDER_TO_DEFAULT_MODEL_NAME, | |
APIProvider, | |
sampling_loop, | |
sampling_loop_sync, | |
) | |
from computer_use_demo.tools import ToolResult | |
CONFIG_DIR = Path("~/.anthropic").expanduser() | |
API_KEY_FILE = CONFIG_DIR / "api_key" | |
WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior" | |
class Sender(StrEnum): | |
USER = "user" | |
BOT = "assistant" | |
TOOL = "tool" | |
def setup_state(state): | |
if "messages" not in state: | |
state["messages"] = [] | |
if "api_key" not in state: | |
# Try to load API key from file first, then environment | |
state["api_key"] = load_from_storage("api_key") or os.getenv("ANTHROPIC_API_KEY", "") | |
if not state["api_key"]: | |
print("API key not found. Please set it in the environment or storage.") | |
if "provider" not in state: | |
state["provider"] = os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC | |
if "provider_radio" not in state: | |
state["provider_radio"] = state["provider"] | |
if "model" not in state: | |
_reset_model(state) | |
if "auth_validated" not in state: | |
state["auth_validated"] = False | |
if "responses" not in state: | |
state["responses"] = {} | |
if "tools" not in state: | |
state["tools"] = {} | |
if "only_n_most_recent_images" not in state: | |
state["only_n_most_recent_images"] = 3 # 10 | |
if "custom_system_prompt" not in state: | |
state["custom_system_prompt"] = load_from_storage("system_prompt") or "" | |
# remove if want to use default system prompt | |
device_os_name = "Windows" if platform.platform == "Windows" else "Mac" if platform.platform == "Darwin" else "Linux" | |
state["custom_system_prompt"] += f"\n\nNOTE: you are operating a {device_os_name} machine" | |
if "hide_images" not in state: | |
state["hide_images"] = False | |
def _reset_model(state): | |
state["model"] = PROVIDER_TO_DEFAULT_MODEL_NAME[cast(APIProvider, state["provider"])] | |
async def main(state): | |
"""Render loop for Gradio""" | |
setup_state(state) | |
return "Setup completed" | |
def validate_auth(provider: APIProvider, api_key: str | None): | |
if provider == APIProvider.ANTHROPIC: | |
if not api_key: | |
return "Enter your Anthropic API key to continue." | |
if provider == APIProvider.BEDROCK: | |
import boto3 | |
if not boto3.Session().get_credentials(): | |
return "You must have AWS credentials set up to use the Bedrock API." | |
if provider == APIProvider.VERTEX: | |
import google.auth | |
from google.auth.exceptions import DefaultCredentialsError | |
if not os.environ.get("CLOUD_ML_REGION"): | |
return "Set the CLOUD_ML_REGION environment variable to use the Vertex API." | |
try: | |
google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) | |
except DefaultCredentialsError: | |
return "Your google cloud credentials are not set up correctly." | |
def load_from_storage(filename: str) -> str | None: | |
"""Load data from a file in the storage directory.""" | |
try: | |
file_path = CONFIG_DIR / filename | |
if file_path.exists(): | |
data = file_path.read_text().strip() | |
if data: | |
return data | |
except Exception as e: | |
print(f"Debug: Error loading {filename}: {e}") | |
return None | |
def save_to_storage(filename: str, data: str) -> None: | |
"""Save data to a file in the storage directory.""" | |
try: | |
CONFIG_DIR.mkdir(parents=True, exist_ok=True) | |
file_path = CONFIG_DIR / filename | |
file_path.write_text(data) | |
# Ensure only user can read/write the file | |
file_path.chmod(0o600) | |
except Exception as e: | |
print(f"Debug: Error saving {filename}: {e}") | |
def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict): | |
response_id = datetime.now().isoformat() | |
response_state[response_id] = response | |
def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict): | |
tool_state[tool_id] = tool_output | |
def _render_message(sender: Sender, message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, state): | |
is_tool_result = not isinstance(message, str) and ( | |
isinstance(message, ToolResult) | |
or message.__class__.__name__ == "ToolResult" | |
or message.__class__.__name__ == "CLIResult" | |
) | |
if not message or ( | |
is_tool_result | |
and state["hide_images"] | |
and not hasattr(message, "error") | |
and not hasattr(message, "output") | |
): | |
return | |
if is_tool_result: | |
message = cast(ToolResult, message) | |
if message.output: | |
return message.output | |
if message.error: | |
return f"Error: {message.error}" | |
if message.base64_image and not state["hide_images"]: | |
return base64.b64decode(message.base64_image) | |
elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock): | |
return message.text | |
elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock): | |
return f"Tool Use: {message.name}\nInput: {message.input}" | |
else: | |
return message | |
# open new tab, open google sheets inside, then create a new blank spreadsheet | |
def process_input(user_input, state): | |
# Ensure the state is properly initialized | |
setup_state(state) | |
# Append the user input to the messages in the state | |
state["messages"].append( | |
{ | |
"role": Sender.USER, | |
"content": [TextBlock(type="text", text=user_input)], | |
} | |
) | |
# Run the sampling loop synchronously and yield messages | |
for message in sampling_loop(state): | |
yield message | |
def accumulate_messages(*args, **kwargs): | |
""" | |
Wrapper function to accumulate messages from sampling_loop_sync. | |
""" | |
accumulated_messages = [] | |
for message in sampling_loop_sync(*args, **kwargs): | |
# Check if the message is already in the accumulated messages | |
if message not in accumulated_messages: | |
accumulated_messages.append(message) | |
# Yield the accumulated messages as a list | |
yield accumulated_messages | |
def sampling_loop(state): | |
# Ensure the API key is present | |
if not state.get("api_key"): | |
raise ValueError("API key is missing. Please set it in the environment or storage.") | |
# Call the sampling loop and yield messages | |
for message in accumulate_messages( | |
system_prompt_suffix=state["custom_system_prompt"], | |
model=state["model"], | |
provider=state["provider"], | |
messages=state["messages"], | |
output_callback=partial(_render_message, Sender.BOT, state=state), | |
tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]), | |
api_response_callback=partial(_api_response_callback, response_state=state["responses"]), | |
api_key=state["api_key"], | |
only_n_most_recent_images=state["only_n_most_recent_images"], | |
): | |
yield message | |
with gr.Blocks() as demo: | |
state = gr.State({}) # Use Gradio's state management | |
gr.Markdown("# Claude Computer Use Demo") | |
if not os.getenv("HIDE_WARNING", False): | |
gr.Markdown(WARNING_TEXT) | |
with gr.Row(): | |
provider = gr.Dropdown( | |
label="API Provider", | |
choices=[option.value for option in APIProvider], | |
value="anthropic", | |
interactive=True, | |
) | |
model = gr.Textbox(label="Model", value="claude-3-5-sonnet-20241022") | |
api_key = gr.Textbox( | |
label="Anthropic API Key", | |
type="password", | |
value="", | |
interactive=True, | |
) | |
only_n_images = gr.Slider( | |
label="Only send N most recent images", | |
minimum=0, | |
value=3, # 10 | |
interactive=True, | |
) | |
custom_prompt = gr.Textbox( | |
label="Custom System Prompt Suffix", | |
value="", | |
interactive=True, | |
) | |
hide_images = gr.Checkbox(label="Hide screenshots", value=False) | |
api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key) | |
chat_input = gr.Textbox(label="Type a message to send to Claude...") | |
# chat_output = gr.Textbox(label="Chat Output", interactive=False) | |
chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True) | |
# Pass state as an input to the function | |
chat_input.submit(process_input, [chat_input, state], chatbot) | |
demo.launch(share=True) | |