Eladlev commited on
Commit
1aa017f
·
verified ·
1 Parent(s): 9b60d66

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +247 -153
app.py CHANGED
@@ -1,166 +1,260 @@
1
- import gradio as gr
2
- import io
 
 
 
 
3
  import os
4
- from PIL import Image, ImageDraw
5
- from anthropic import Anthropic
 
 
 
 
 
 
6
  from anthropic.types import TextBlock
7
  from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
8
- max_tokens = 4096
9
- import base64
10
- model = 'claude-3-5-sonnet-20241022'
11
- system = """<SYSTEM_CAPABILITY>
12
- * You are utilizing a Windows system with internet access.
13
- * The current date is Monday, November 18, 2024.
14
- </SYSTEM_CAPABILITY>"""
15
-
16
- def save_image_or_get_url(image, filename="processed_image.png"):
17
- if not os.path.isdir("static"):
18
- os.mkdir("static")
19
- filepath = os.path.join("static", filename)
20
- image.save(filepath)
21
- return filepath
22
-
23
- def draw_circle_on_image(image, center, radius=30):
24
- """
25
- Draws a circle on the given image using a center point and radius.
26
-
27
- Parameters:
28
- image (PIL.Image): The image to draw on.
29
- center (tuple): A tuple (x, y) representing the center of the circle.
30
- radius (int): The radius of the circle.
31
-
32
- Returns:
33
- PIL.Image: The image with the circle drawn.
34
- """
35
- if not isinstance(center, tuple) or len(center) != 2:
36
- raise ValueError("Center must be a tuple of two values (x, y).")
37
- if not isinstance(radius, (int, float)) or radius <= 0:
38
- raise ValueError("Radius must be a positive number.")
39
-
40
- # Calculate the bounding box for the circle
41
- bbox = [
42
- center[0] - radius, center[1] - radius, # Top-left corner
43
- center[0] + radius, center[1] + radius # Bottom-right corner
44
- ]
45
-
46
- # Create a drawing context
47
- draw = ImageDraw.Draw(image)
48
-
49
- # Draw the circle
50
- draw.ellipse(bbox, outline="red", width=15) # Change outline color and width as needed
51
-
52
- return image
53
-
54
-
55
- def pil_image_to_base64(pil_image):
56
- # Save the PIL image to an in-memory buffer as a file-like object
57
- buffered = io.BytesIO()
58
- pil_image.save(buffered, format="PNG") # Specify format (e.g., PNG, JPEG)
59
- buffered.seek(0) # Rewind the buffer to the beginning
60
-
61
- # Encode the bytes from the buffer to Base64
62
- image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
63
- return image_data
64
-
65
-
66
-
67
-
68
-
69
-
70
-
71
- # Function to simulate chatbot responses
72
- def chatbot_response(input_text, image, key, chat_history):
73
-
74
- if not key:
75
- return chat_history + [[input_text, "Please enter a valid key."]]
76
- if image is None:
77
- return chat_history + [[input_text, "Please upload an image."]]
78
- api_key =key
79
- client = Anthropic(api_key=api_key)
80
-
81
-
82
-
83
- messages = [{'role': 'user', 'content': [TextBlock(text=f'Look at my screenshot, {input_text}', type='text')]},
84
- {'role': 'assistant', 'content': [BetaTextBlock(
85
- text="I'll help you check your screen, but first I need to take a screenshot to see what you're looking at.",
86
- type='text'), BetaToolUseBlock(id='toolu_01PSTVtavFgmx6ctaiSvacCB',
87
- input={'action': 'screenshot'}, name='computer',
88
- type='tool_use')]}]
89
- image_data = pil_image_to_base64(image)
90
-
91
- tool_res = {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'toolu_01PSTVtavFgmx6ctaiSvacCB',
92
- 'is_error': False,
93
- 'content': [{'type': 'image',
94
- 'source': {'type': 'base64', 'media_type': 'image/png',
95
- 'data': image_data}}]}]}
96
- messages.append(tool_res)
97
- params = [{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1512, 'display_height_px': 982,
98
- 'display_number': None}, {'type': 'bash_20241022', 'name': 'bash'},
99
- {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}]
100
- raw_response = client.beta.messages.with_raw_response.create(
101
- max_tokens=max_tokens,
102
- messages=messages,
103
- model=model,
104
- system=system,
105
- tools=params,
106
- betas=["computer-use-2024-10-22"],
107
- temperature=0.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  )
109
- response = raw_response.parse()
110
- scale_x = image.width // 1512
111
- scale_y = image.height // 982
112
- for r in response.content:
113
- if hasattr(r, 'text'):
114
- chat_history = chat_history + [[input_text, r.text]]
115
-
116
- if hasattr(r, 'input') and 'coordinate' in r.input:
117
- coordinate = r.input['coordinate']
118
- new_image = draw_circle_on_image(image, (coordinate[0] * scale_x, coordinate[1] * scale_y))
119
-
120
- # Save the image or encode it as a base64 string if needed
121
- image_url = save_image_or_get_url(
122
- new_image) # Define this function to save or generate the URL for the image
123
-
124
- # Include the image as part of the chat history
125
- image_html = f'<img src="{image_url}" alt="Processed Image" style="max-width: 100%; max-height: 200px;">'
126
- chat_history = chat_history + [[None, (image_url,)]]
127
- return chat_history
128
-
129
- # Read the image and encode it in base64
130
-
131
-
132
 
 
 
 
133
 
134
 
135
- # Simulated response
136
- response = f"Received input: {input_text}\nKey: {key}\nImage uploaded successfully!"
137
- return chat_history + [[input_text, response]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
- # Create the Gradio interface
141
  with gr.Blocks() as demo:
142
- with gr.Row():
143
- with gr.Column():
144
- image_input = gr.Image(label="Upload Image", type="pil", interactive=True)
145
- with gr.Column():
146
- chatbot = gr.Chatbot(label="Chatbot Interaction", height=400)
147
 
148
- with gr.Row():
149
- user_input = gr.Textbox(label="Type your message here", placeholder="Enter your message...")
150
- key_input = gr.Textbox(label="API Key", placeholder="Enter your key...", type="password")
151
-
152
- # Button to submit
153
- submit_button = gr.Button("Submit")
154
 
155
- # Initialize chat history
156
- chat_history = gr.State(value=[])
157
 
158
- # Set interactions
159
- submit_button.click(
160
- fn=chatbot_response,
161
- inputs=[user_input, image_input, key_input, chat_history],
162
- outputs=[chatbot],
163
- )
164
-
165
- # Launch the app
166
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Entrypoint for Gradio, see https://gradio.app/
3
+ """
4
+
5
+ import asyncio
6
+ import base64
7
  import os
8
+ from datetime import datetime
9
+ from enum import StrEnum
10
+ from functools import partial
11
+ from pathlib import Path
12
+ from typing import cast, Dict
13
+
14
+ import gradio as gr
15
+ from anthropic import APIResponse
16
  from anthropic.types import TextBlock
17
  from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
18
+ from anthropic.types.tool_use_block import ToolUseBlock
19
+
20
+ from computer_use_demo.loop import (
21
+ PROVIDER_TO_DEFAULT_MODEL_NAME,
22
+ APIProvider,
23
+ sampling_loop,
24
+ sampling_loop_sync,
25
+ )
26
+
27
+ from computer_use_demo.tools import ToolResult
28
+
29
+
30
+ CONFIG_DIR = Path("~/.anthropic").expanduser()
31
+ API_KEY_FILE = CONFIG_DIR / "api_key"
32
+
33
+ WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
34
+
35
+
36
+ class Sender(StrEnum):
37
+ USER = "user"
38
+ BOT = "assistant"
39
+ TOOL = "tool"
40
+
41
+
42
+ def setup_state(state):
43
+ if "messages" not in state:
44
+ state["messages"] = []
45
+ if "api_key" not in state:
46
+ # Try to load API key from file first, then environment
47
+ state["api_key"] = load_from_storage("api_key") or os.getenv("ANTHROPIC_API_KEY", "")
48
+ if not state["api_key"]:
49
+ print("API key not found. Please set it in the environment or storage.")
50
+ if "provider" not in state:
51
+ state["provider"] = os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
52
+ if "provider_radio" not in state:
53
+ state["provider_radio"] = state["provider"]
54
+ if "model" not in state:
55
+ _reset_model(state)
56
+ if "auth_validated" not in state:
57
+ state["auth_validated"] = False
58
+ if "responses" not in state:
59
+ state["responses"] = {}
60
+ if "tools" not in state:
61
+ state["tools"] = {}
62
+ if "only_n_most_recent_images" not in state:
63
+ state["only_n_most_recent_images"] = 10
64
+ if "custom_system_prompt" not in state:
65
+ state["custom_system_prompt"] = load_from_storage("system_prompt") or ""
66
+ # remove if want to use default system prompt
67
+ state["custom_system_prompt"] += "\n\nNote that you are operating on a Windows machine, so you should use double click to open a desktop application"
68
+ if "hide_images" not in state:
69
+ state["hide_images"] = False
70
+
71
+
72
+ def _reset_model(state):
73
+ state["model"] = PROVIDER_TO_DEFAULT_MODEL_NAME[cast(APIProvider, state["provider"])]
74
+
75
+
76
+ async def main(state):
77
+ """Render loop for Gradio"""
78
+ setup_state(state)
79
+ return "Setup completed"
80
+
81
+
82
+ def validate_auth(provider: APIProvider, api_key: str | None):
83
+ if provider == APIProvider.ANTHROPIC:
84
+ if not api_key:
85
+ return "Enter your Anthropic API key to continue."
86
+ if provider == APIProvider.BEDROCK:
87
+ import boto3
88
+
89
+ if not boto3.Session().get_credentials():
90
+ return "You must have AWS credentials set up to use the Bedrock API."
91
+ if provider == APIProvider.VERTEX:
92
+ import google.auth
93
+ from google.auth.exceptions import DefaultCredentialsError
94
+
95
+ if not os.environ.get("CLOUD_ML_REGION"):
96
+ return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
97
+ try:
98
+ google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
99
+ except DefaultCredentialsError:
100
+ return "Your google cloud credentials are not set up correctly."
101
+
102
+
103
+ def load_from_storage(filename: str) -> str | None:
104
+ """Load data from a file in the storage directory."""
105
+ try:
106
+ file_path = CONFIG_DIR / filename
107
+ if file_path.exists():
108
+ data = file_path.read_text().strip()
109
+ if data:
110
+ return data
111
+ except Exception as e:
112
+ print(f"Debug: Error loading {filename}: {e}")
113
+ return None
114
+
115
+
116
+ def save_to_storage(filename: str, data: str) -> None:
117
+ """Save data to a file in the storage directory."""
118
+ try:
119
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
120
+ file_path = CONFIG_DIR / filename
121
+ file_path.write_text(data)
122
+ # Ensure only user can read/write the file
123
+ file_path.chmod(0o600)
124
+ except Exception as e:
125
+ print(f"Debug: Error saving {filename}: {e}")
126
+
127
+
128
+ def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict):
129
+ response_id = datetime.now().isoformat()
130
+ response_state[response_id] = response
131
+
132
+
133
+ def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict):
134
+ tool_state[tool_id] = tool_output
135
+
136
+
137
+ def _render_message(sender: Sender, message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, state):
138
+ is_tool_result = not isinstance(message, str) and (
139
+ isinstance(message, ToolResult)
140
+ or message.__class__.__name__ == "ToolResult"
141
+ or message.__class__.__name__ == "CLIResult"
142
+ )
143
+ if not message or (
144
+ is_tool_result
145
+ and state["hide_images"]
146
+ and not hasattr(message, "error")
147
+ and not hasattr(message, "output")
148
+ ):
149
+ return
150
+ if is_tool_result:
151
+ message = cast(ToolResult, message)
152
+ if message.output:
153
+ return message.output
154
+ if message.error:
155
+ return f"Error: {message.error}"
156
+ if message.base64_image and not state["hide_images"]:
157
+ return base64.b64decode(message.base64_image)
158
+ elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
159
+ return message.text
160
+ elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
161
+ return f"Tool Use: {message.name}\nInput: {message.input}"
162
+ else:
163
+ return message
164
+ # open new tab, open google sheets inside, then create a new blank spreadsheet
165
+
166
+ def process_input(user_input, state):
167
+ # Ensure the state is properly initialized
168
+ setup_state(state)
169
+
170
+ # Append the user input to the messages in the state
171
+ state["messages"].append(
172
+ {
173
+ "role": Sender.USER,
174
+ "content": [TextBlock(type="text", text=user_input)],
175
+ }
176
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ # Run the sampling loop synchronously and yield messages
179
+ for message in sampling_loop(state):
180
+ yield message
181
 
182
 
183
+ def accumulate_messages(*args, **kwargs):
184
+ """
185
+ Wrapper function to accumulate messages from sampling_loop_sync.
186
+ """
187
+ accumulated_messages = []
188
+
189
+ for message in sampling_loop_sync(*args, **kwargs):
190
+ # Check if the message is already in the accumulated messages
191
+ if message not in accumulated_messages:
192
+ accumulated_messages.append(message)
193
+ # Yield the accumulated messages as a list
194
+ yield accumulated_messages
195
+
196
+
197
+ def sampling_loop(state):
198
+ # Ensure the API key is present
199
+ if not state.get("api_key"):
200
+ raise ValueError("API key is missing. Please set it in the environment or storage.")
201
+
202
+ # Call the sampling loop and yield messages
203
+ for message in accumulate_messages(
204
+ system_prompt_suffix=state["custom_system_prompt"],
205
+ model=state["model"],
206
+ provider=state["provider"],
207
+ messages=state["messages"],
208
+ output_callback=partial(_render_message, Sender.BOT, state=state),
209
+ tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
210
+ api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
211
+ api_key=state["api_key"],
212
+ only_n_most_recent_images=state["only_n_most_recent_images"],
213
+ ):
214
+ yield message
215
 
216
 
 
217
  with gr.Blocks() as demo:
218
+ state = gr.State({}) # Use Gradio's state management
 
 
 
 
219
 
220
+ gr.Markdown("# Claude Computer Use Demo")
 
 
 
 
 
221
 
222
+ if not os.getenv("HIDE_WARNING", False):
223
+ gr.Markdown(WARNING_TEXT)
224
 
225
+ with gr.Row():
226
+ provider = gr.Dropdown(
227
+ label="API Provider",
228
+ choices=[option.value for option in APIProvider],
229
+ value="anthropic",
230
+ interactive=True,
231
+ )
232
+ model = gr.Textbox(label="Model", value="claude-3-5-sonnet-20241022")
233
+ api_key = gr.Textbox(
234
+ label="Anthropic API Key",
235
+ type="password",
236
+ value="",
237
+ interactive=True,
238
+ )
239
+ only_n_images = gr.Slider(
240
+ label="Only send N most recent images",
241
+ minimum=0,
242
+ value=10,
243
+ interactive=True,
244
+ )
245
+ custom_prompt = gr.Textbox(
246
+ label="Custom System Prompt Suffix",
247
+ value="",
248
+ interactive=True,
249
+ )
250
+ hide_images = gr.Checkbox(label="Hide screenshots", value=False)
251
+
252
+ api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key)
253
+ chat_input = gr.Textbox(label="Type a message to send to Claude...")
254
+ # chat_output = gr.Textbox(label="Chat Output", interactive=False)
255
+ chatbot = gr.Chatbot(label="Chatbot History")
256
+
257
+ # Pass state as an input to the function
258
+ chat_input.submit(process_input, [chat_input, state], chatbot)
259
+
260
+ demo.launch(share=True)