Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitattributes +10 -0
- .gitignore +18 -0
- LICENSE +190 -0
- README.md +269 -8
- app.py +633 -0
- assets/Teaser.gif +3 -0
- assets/examples/init_states/amazon.png +3 -0
- assets/examples/init_states/booking.png +3 -0
- assets/examples/init_states/honkai_star_rail.png +3 -0
- assets/examples/init_states/honkai_star_rail_showui.png +3 -0
- assets/examples/init_states/ign.png +3 -0
- assets/examples/init_states/powerpoint.png +3 -0
- assets/examples/init_states/powerpoint_homepage.png +3 -0
- assets/examples/ootb_examples.json +73 -0
- assets/gradio_interface.png +3 -0
- assets/ootb_icon.png +0 -0
- assets/ootb_logo.png +0 -0
- assets/wechat_3.jpg +3 -0
- computer_use_demo/__init__.py +0 -0
- computer_use_demo/executor/anthropic_executor.py +135 -0
- computer_use_demo/executor/showui_executor.py +376 -0
- computer_use_demo/gui_agent/actor/showui_agent.py +178 -0
- computer_use_demo/gui_agent/actor/uitars_agent.py +169 -0
- computer_use_demo/gui_agent/llm_utils/llm_utils.py +109 -0
- computer_use_demo/gui_agent/llm_utils/oai.py +218 -0
- computer_use_demo/gui_agent/llm_utils/qwen.py +108 -0
- computer_use_demo/gui_agent/llm_utils/run_llm.py +44 -0
- computer_use_demo/gui_agent/planner/anthropic_agent.py +206 -0
- computer_use_demo/gui_agent/planner/api_vlm_planner.py +305 -0
- computer_use_demo/gui_agent/planner/local_vlm_planner.py +235 -0
- computer_use_demo/loop.py +276 -0
- computer_use_demo/remote_inference.py +453 -0
- computer_use_demo/tools/__init__.py +16 -0
- computer_use_demo/tools/base.py +69 -0
- computer_use_demo/tools/bash.py +136 -0
- computer_use_demo/tools/collection.py +41 -0
- computer_use_demo/tools/colorful_text.py +27 -0
- computer_use_demo/tools/computer.py +621 -0
- computer_use_demo/tools/edit.py +290 -0
- computer_use_demo/tools/logger.py +21 -0
- computer_use_demo/tools/run.py +42 -0
- computer_use_demo/tools/screen_capture.py +171 -0
- dev-requirements.txt +23 -0
- docs/README_cn.md +172 -0
- install_tools/install_showui-awq-4bit.py +17 -0
- install_tools/install_showui.py +17 -0
- install_tools/install_uitars-2b-sft.py +17 -0
- install_tools/test_ui-tars_server.py +82 -0
.gitattributes
CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/Teaser.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/examples/init_states/amazon.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/examples/init_states/booking.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/examples/init_states/honkai_star_rail.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/examples/init_states/honkai_star_rail_showui.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/examples/init_states/ign.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/examples/init_states/powerpoint.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/examples/init_states/powerpoint_homepage.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/gradio_interface.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/wechat_3.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.venv
|
2 |
+
.ruff_cache
|
3 |
+
__pycache__
|
4 |
+
.pytest_cache
|
5 |
+
.cache
|
6 |
+
.ipynb_checkpoints
|
7 |
+
.ipynb
|
8 |
+
.DS_Store
|
9 |
+
/tmp
|
10 |
+
/.gradio
|
11 |
+
/.zed
|
12 |
+
/showui*
|
13 |
+
/ui-tars*
|
14 |
+
/demo
|
15 |
+
/Qwen*
|
16 |
+
/install_tools/install_qwen*
|
17 |
+
/dev_tools*
|
18 |
+
test.ipynb
|
LICENSE
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
Copyright [2024] [Show Lab Computer-Use-OOTB Team]
|
179 |
+
|
180 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
181 |
+
you may not use this file except in compliance with the License.
|
182 |
+
You may obtain a copy of the License at
|
183 |
+
|
184 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
185 |
+
|
186 |
+
Unless required by applicable law or agreed to in writing, software
|
187 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
188 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
189 |
+
See the License for the specific language governing permissions and
|
190 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,273 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 👀
|
4 |
-
colorFrom: pink
|
5 |
-
colorTo: yellow
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.16.2
|
8 |
app_file: app.py
|
9 |
-
|
|
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: computer_use_ootb
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 5.13.2
|
6 |
---
|
7 |
+
<h2 align="center">
|
8 |
+
<a href="https://computer-use-ootb.github.io">
|
9 |
+
<img src="./assets/ootb_logo.png" alt="Logo" style="display: block; margin: 0 auto; filter: invert(1) brightness(2);">
|
10 |
+
</a>
|
11 |
+
</h2>
|
12 |
+
|
13 |
+
|
14 |
+
<h5 align="center"> If you like our project, please give us a star ⭐ on GitHub for the latest update.</h5>
|
15 |
+
|
16 |
+
<h5 align=center>
|
17 |
+
|
18 |
+
[](https://arxiv.org/abs/2411.10323)
|
19 |
+
[](https://computer-use-ootb.github.io)
|
20 |
+
[](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fshowlab%2Fcomputer_use_ootb&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
|
21 |
+
|
22 |
+
|
23 |
+
</h5>
|
24 |
+
|
25 |
+
## <img src="./assets/ootb_icon.png" alt="Star" style="height:25px; vertical-align:middle; filter: invert(1) brightness(2);"> Overview
|
26 |
+
**Computer Use <span style="color:rgb(106, 158, 210)">O</span><span style="color:rgb(111, 163, 82)">O</span><span style="color:rgb(209, 100, 94)">T</span><span style="color:rgb(238, 171, 106)">B</span>**<img src="./assets/ootb_icon.png" alt="Star" style="height:20px; vertical-align:middle; filter: invert(1) brightness(2);"> is an out-of-the-box (OOTB) solution for Desktop GUI Agent, including API-based (**Claude 3.5 Computer Use**) and locally-running models (**<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span>UI**, **UI-TARS**).
|
27 |
+
|
28 |
+
**No Docker** is required, and it supports both **Windows** and **macOS**. OOTB provides a user-friendly interface based on Gradio. 🎨
|
29 |
+
|
30 |
+
Visit our study on GUI Agent of Claude 3.5 Computer Use [[project page]](https://computer-use-ootb.github.io). 🌐
|
31 |
+
|
32 |
+
## Update
|
33 |
+
- **[2025/02/08]** We've added the support for [**UI-TARS**](https://github.com/bytedance/UI-TARS). Follow [Cloud Deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#cloud-deployment) or [VLLM deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#local-deployment-vllm) to implement UI-TARS and run it locally in OOTB.
|
34 |
+
- **Major Update! [2024/12/04]** **Local Run🔥** is now live! Say hello to [**<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span>UI**](https://github.com/showlab/ShowUI), an open-source 2B vision-language-action (VLA) model for GUI Agent. Now compatible with `"gpt-4o + ShowUI" (~200x cheaper)`* & `"Qwen2-VL + ShowUI" (~30x cheaper)`* for only few cents for each task💰! <span style="color: grey; font-size: small;">*compared to Claude Computer Use</span>.
|
35 |
+
- **[2024/11/20]** We've added some examples to help you get hands-on experience with Claude 3.5 Computer Use.
|
36 |
+
- **[2024/11/19]** Forget about the single-display limit set by Anthropic - you can now use **multiple displays** 🎉!
|
37 |
+
- **[2024/11/18]** We've released a deep analysis of Claude 3.5 Computer Use: [https://arxiv.org/abs/2411.10323](https://arxiv.org/abs/2411.10323).
|
38 |
+
- **[2024/11/11]** Forget about the low-resolution display limit set by Anthropic — you can now use *any resolution you like* and still keep the **screenshot token cost low** 🎉!
|
39 |
+
- **[2024/11/11]** Now both **Windows** and **macOS** platforms are supported 🎉!
|
40 |
+
- **[2024/10/25]** Now you can **Remotely Control** your computer 💻 through your mobile device 📱 — **No Mobile App Installation** required! Give it a try and have fun 🎉.
|
41 |
+
|
42 |
+
|
43 |
+
## Demo Video
|
44 |
+
|
45 |
+
https://github.com/user-attachments/assets/f50b7611-2350-4712-af9e-3d31e30020ee
|
46 |
+
|
47 |
+
<div style="display: flex; justify-content: space-around;">
|
48 |
+
<a href="https://youtu.be/Ychd-t24HZw" target="_blank" style="margin-right: 10px;">
|
49 |
+
<img src="https://img.youtube.com/vi/Ychd-t24HZw/maxresdefault.jpg" alt="Watch the video" width="48%">
|
50 |
+
</a>
|
51 |
+
<a href="https://youtu.be/cvgPBazxLFM" target="_blank">
|
52 |
+
<img src="https://img.youtube.com/vi/cvgPBazxLFM/maxresdefault.jpg" alt="Watch the video" width="48%">
|
53 |
+
</a>
|
54 |
+
</div>
|
55 |
+
|
56 |
+
|
57 |
+
## 🚀 Getting Started
|
58 |
+
|
59 |
+
### 0. Prerequisites
|
60 |
+
- Instal Miniconda on your system through this [link](https://www.anaconda.com/download?utm_source=anacondadocs&utm_medium=documentation&utm_campaign=download&utm_content=topnavalldocs). (**Python Version: >= 3.12**).
|
61 |
+
- Hardware Requirements (optional, for ShowUI local-run):
|
62 |
+
- **Windows (CUDA-enabled):** A compatible NVIDIA GPU with CUDA support, >=6GB GPU memory
|
63 |
+
- **macOS (Apple Silicon):** M1 chip (or newer), >=16GB unified RAM
|
64 |
+
|
65 |
+
|
66 |
+
### 1. Clone the Repository 📂
|
67 |
+
Open the Conda Terminal. (After installation Of Miniconda, it will appear in the Start menu.)
|
68 |
+
Run the following command on **Conda Terminal**.
|
69 |
+
```bash
|
70 |
+
git clone https://github.com/showlab/computer_use_ootb.git
|
71 |
+
cd computer_use_ootb
|
72 |
+
```
|
73 |
+
|
74 |
+
### 2.1 Install Dependencies 🔧
|
75 |
+
```bash
|
76 |
+
pip install -r dev-requirements.txt
|
77 |
+
```
|
78 |
+
|
79 |
+
### 2.2 (Optional) Get Prepared for **<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span>UI** Local-Run
|
80 |
+
|
81 |
+
1. Download all files of the ShowUI-2B model via the following command. Ensure the `ShowUI-2B` folder is under the `computer_use_ootb` folder.
|
82 |
+
|
83 |
+
```python
|
84 |
+
python install_tools/install_showui.py
|
85 |
+
```
|
86 |
+
|
87 |
+
2. Make sure to install the correct GPU version of PyTorch (CUDA, MPS, etc.) on your machine. See [install guide and verification](https://pytorch.org/get-started/locally/).
|
88 |
+
|
89 |
+
3. Get API Keys for [GPT-4o](https://platform.openai.com/docs/quickstart) or [Qwen-VL](https://help.aliyun.com/zh/dashscope/developer-reference/acquisition-and-configuration-of-api-key). For mainland China users, Qwen API free trial for first 1 mil tokens is [available](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api).
|
90 |
+
|
91 |
+
### 2.3 (Optional) Get Prepared for **UI-TARS** Local-Run
|
92 |
+
|
93 |
+
1. Follow [Cloud Deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#cloud-deployment) or [VLLM deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#local-deployment-vllm) guides to deploy your UI-TARS server.
|
94 |
+
|
95 |
+
2. Test your UI-TARS sever with the script `.\install_tools\test_ui-tars_server.py`.
|
96 |
+
|
97 |
+
### 2.4 (Optional) If you want to deploy Qwen model as planner on ssh server
|
98 |
+
1. git clone this project on your ssh server
|
99 |
+
|
100 |
+
2. python computer_use_demo/remote_inference.py
|
101 |
+
### 3. Start the Interface ▶️
|
102 |
+
|
103 |
+
**Start the OOTB interface:**
|
104 |
+
```bash
|
105 |
+
python app.py
|
106 |
+
```
|
107 |
+
If you successfully start the interface, you will see two URLs in the terminal:
|
108 |
+
```bash
|
109 |
+
* Running on local URL: http://127.0.0.1:7860
|
110 |
+
* Running on public URL: https://xxxxxxxxxxxxxxxx.gradio.live (Do not share this link with others, or they will be able to control your computer.)
|
111 |
+
```
|
112 |
+
|
113 |
+
|
114 |
+
> <u>For convenience</u>, we recommend running one or more of the following command to set API keys to the environment variables before starting the interface. Then you don’t need to manually pass the keys each run. On Windows Powershell (via the `set` command if on cmd):
|
115 |
+
> ```bash
|
116 |
+
> $env:ANTHROPIC_API_KEY="sk-xxxxx" (Replace with your own key)
|
117 |
+
> $env:QWEN_API_KEY="sk-xxxxx"
|
118 |
+
> $env:OPENAI_API_KEY="sk-xxxxx"
|
119 |
+
> ```
|
120 |
+
> On macOS/Linux, replace `$env:ANTHROPIC_API_KEY` with `export ANTHROPIC_API_KEY` in the above command.
|
121 |
+
|
122 |
+
|
123 |
+
### 4. Control Your Computer with Any Device can Access the Internet
|
124 |
+
- **Computer to be controlled**: The one installed software.
|
125 |
+
- **Device Send Command**: The one opens the website.
|
126 |
+
|
127 |
+
Open the website at http://localhost:7860/ (if you're controlling the computer itself) or https://xxxxxxxxxxxxxxxxx.gradio.live in your mobile browser for remote control.
|
128 |
+
|
129 |
+
Enter the Anthropic API key (you can obtain it through this [website](https://console.anthropic.com/settings/keys)), then give commands to let the AI perform your tasks.
|
130 |
+
|
131 |
+
### ShowUI Advanced Settings
|
132 |
+
|
133 |
+
We provide a 4-bit quantized ShowUI-2B model for cost-efficient inference (currently **only support CUDA devices**). To download the 4-bit quantized ShowUI-2B model:
|
134 |
+
```
|
135 |
+
python install_tools/install_showui-awq-4bit.py
|
136 |
+
```
|
137 |
+
Then, enable the quantized setting in the 'ShowUI Advanced Settings' dropdown menu.
|
138 |
+
|
139 |
+
Besides, we also provide a slider to quickly adjust the `max_pixel` parameter in the ShowUI model. This controls the visual input size of the model and greatly affects the memory and inference speed.
|
140 |
+
|
141 |
+
## 📊 GUI Agent Model Zoo
|
142 |
+
|
143 |
+
Now, OOTB supports customizing the GUI Agent via the following models:
|
144 |
+
|
145 |
+
- **Unified Model**: Unified planner & actor, can both make the high-level planning and take the low-level control.
|
146 |
+
- **Planner**: General-purpose LLMs, for handling the high-level planning and decision-making.
|
147 |
+
- **Actor**: Vision-language-action models, for handling the low-level control and action command generation.
|
148 |
+
|
149 |
+
|
150 |
+
<div align="center">
|
151 |
+
<b>Supported GUI Agent Models, OOTB</b>
|
152 |
+
|
153 |
+
</div>
|
154 |
+
<table align="center">
|
155 |
+
<tbody>
|
156 |
+
<tr align="center" valign="bottom">
|
157 |
+
<td>
|
158 |
+
<b>[API] Unified Model</b>
|
159 |
+
</td>
|
160 |
+
<td>
|
161 |
+
<b>[API] Planner</b>
|
162 |
+
</td>
|
163 |
+
<td>
|
164 |
+
<b>[Local] Planner</b>
|
165 |
+
</td>
|
166 |
+
<td>
|
167 |
+
<b>[API] Actor</b>
|
168 |
+
</td>
|
169 |
+
<td>
|
170 |
+
<b>[Local] Actor</b>
|
171 |
+
</td>
|
172 |
+
</tr>
|
173 |
+
<tr valign="top">
|
174 |
+
<td>
|
175 |
+
<ul>
|
176 |
+
<li><a href="">Claude 3.5 Sonnet</a></li>
|
177 |
+
</ul>
|
178 |
+
</td>
|
179 |
+
<td>
|
180 |
+
<ul>
|
181 |
+
<li><a href="">GPT-4o</a></li>
|
182 |
+
<li><a href="">Qwen2-VL-Max</a></li>
|
183 |
+
<li><a href="">Qwen2-VL-2B(ssh)</a></li>
|
184 |
+
<li><a href="">Qwen2-VL-7B(ssh)</a></li>
|
185 |
+
<li><a href="">Qwen2.5-VL-7B(ssh)</a></li>
|
186 |
+
<li><a href="">Deepseek V3 (soon)</a></li>
|
187 |
+
</ul>
|
188 |
+
</td>
|
189 |
+
<td>
|
190 |
+
<ul>
|
191 |
+
<li><a href="">Qwen2-VL-2B</a></li>
|
192 |
+
<li><a href="">Qwen2-VL-7B</a></li>
|
193 |
+
</ul>
|
194 |
+
</td>
|
195 |
+
<td>
|
196 |
+
<ul>
|
197 |
+
<li><a href="https://github.com/showlab/ShowUI">ShowUI</a></li>
|
198 |
+
<li><a href="https://huggingface.co/bytedance-research/UI-TARS-7B-DPO">UI-TARS-7B/72B-DPO (soon)</a></li>
|
199 |
+
</ul>
|
200 |
+
</td>
|
201 |
+
<td>
|
202 |
+
<ul>
|
203 |
+
<li><a href="https://github.com/showlab/ShowUI">ShowUI</a></li>
|
204 |
+
<li><a href="https://huggingface.co/bytedance-research/UI-TARS-7B-DPO">UI-TARS-7B/72B-DPO</a></li>
|
205 |
+
</ul>
|
206 |
+
</td>
|
207 |
+
</tr>
|
208 |
+
</td>
|
209 |
+
</table>
|
210 |
+
|
211 |
+
> where [API] models are based on API calling the LLMs that can inference remotely,
|
212 |
+
and [Local] models can use your own device that inferences locally with no API costs.
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
## 🖥️ Supported Systems
|
217 |
+
- **Windows** (Claude ✅, ShowUI ✅)
|
218 |
+
- **macOS** (Claude ✅, ShowUI ✅)
|
219 |
+
|
220 |
+
## 👓 OOTB Iterface
|
221 |
+
<div style="display: flex; align-items: center; gap: 10px;">
|
222 |
+
<figure style="text-align: center;">
|
223 |
+
<img src="./assets/gradio_interface.png" alt="Desktop Interface" style="width: auto; object-fit: contain;">
|
224 |
+
</figure>
|
225 |
+
</div>
|
226 |
+
|
227 |
+
|
228 |
+
## ⚠️ Risks
|
229 |
+
- **Potential Dangerous Operations by the Model**: The models' performance is still limited and may generate unintended or potentially harmful outputs. Recommend continuously monitoring the AI's actions.
|
230 |
+
- **Cost Control**: Each task may cost a few dollars for Claude 3.5 Computer Use.💸
|
231 |
+
|
232 |
+
## 📅 Roadmap
|
233 |
+
- [ ] **Explore available features**
|
234 |
+
- [ ] The Claude API seems to be unstable when solving tasks. We are investigating the reasons: resolutions, types of actions required, os platforms, or planning mechanisms. Welcome any thoughts or comments on it.
|
235 |
+
- [ ] **Interface Design**
|
236 |
+
- [x] **Support for Gradio** ✨
|
237 |
+
- [ ] **Simpler Installation**
|
238 |
+
- [ ] **More Features**... 🚀
|
239 |
+
- [ ] **Platform**
|
240 |
+
- [x] **Windows**
|
241 |
+
- [x] **macOS**
|
242 |
+
- [x] **Mobile** (Send command)
|
243 |
+
- [ ] **Mobile** (Be controlled)
|
244 |
+
- [ ] **Support for More MLLMs**
|
245 |
+
- [x] **Claude 3.5 Sonnet** 🎵
|
246 |
+
- [x] **GPT-4o**
|
247 |
+
- [x] **Qwen2-VL**
|
248 |
+
- [ ] **Local MLLMs**
|
249 |
+
- [ ] ...
|
250 |
+
- [ ] **Improved Prompting Strategy**
|
251 |
+
- [ ] Optimize prompts for cost-efficiency. 💡
|
252 |
+
- [x] **Improved Inference Speed**
|
253 |
+
- [x] Support int4 Quantization.
|
254 |
+
|
255 |
+
## Join Discussion
|
256 |
+
Welcome to discuss with us and continuously improve the user experience of Computer Use - OOTB. Reach us using this [**Discord Channel**](https://discord.gg/vMMJTSew37) or the WeChat QR code below!
|
257 |
+
|
258 |
+
<div style="display: flex; flex-direction: row; justify-content: space-around;">
|
259 |
+
|
260 |
+
<!-- <img src="./assets/wechat_2.jpg" alt="gradio_interface" width="30%"> -->
|
261 |
+
<img src="./assets/wechat_3.jpg" alt="gradio_interface" width="30%">
|
262 |
+
|
263 |
+
</div>
|
264 |
+
|
265 |
+
<div style="height: 30px;"></div>
|
266 |
+
|
267 |
+
<hr>
|
268 |
+
<a href="https://computer-use-ootb.github.io">
|
269 |
+
<img src="./assets/ootb_logo.png" alt="Logo" width="30%" style="display: block; margin: 0 auto; filter: invert(1) brightness(2);">
|
270 |
+
</a>
|
271 |
+
|
272 |
+
|
273 |
|
|
app.py
ADDED
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Entrypoint for Gradio, see https://gradio.app/
|
3 |
+
"""
|
4 |
+
|
5 |
+
import platform
|
6 |
+
import asyncio
|
7 |
+
import base64
|
8 |
+
import os
|
9 |
+
import io
|
10 |
+
import json
|
11 |
+
from datetime import datetime
|
12 |
+
from enum import StrEnum
|
13 |
+
from functools import partial
|
14 |
+
from pathlib import Path
|
15 |
+
from typing import cast, Dict
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
import gradio as gr
|
19 |
+
from anthropic import APIResponse
|
20 |
+
from anthropic.types import TextBlock
|
21 |
+
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
22 |
+
from anthropic.types.tool_use_block import ToolUseBlock
|
23 |
+
|
24 |
+
from screeninfo import get_monitors
|
25 |
+
from computer_use_demo.tools.logger import logger, truncate_string
|
26 |
+
|
27 |
+
logger.info("Starting the gradio app")
|
28 |
+
|
29 |
+
screens = get_monitors()
|
30 |
+
logger.info(f"Found {len(screens)} screens")
|
31 |
+
|
32 |
+
from computer_use_demo.loop import APIProvider, sampling_loop_sync
|
33 |
+
|
34 |
+
from computer_use_demo.tools import ToolResult
|
35 |
+
from computer_use_demo.tools.computer import get_screen_details
|
36 |
+
SCREEN_NAMES, SELECTED_SCREEN_INDEX = get_screen_details()
|
37 |
+
|
38 |
+
API_KEY_FILE = "./api_keys.json"
|
39 |
+
|
40 |
+
WARNING_TEXT = "⚠️ Security Alert: Do not provide access to sensitive accounts or data, as malicious web content can hijack Agent's behavior. Keep monitor on the Agent's actions."
|
41 |
+
|
42 |
+
|
43 |
+
def setup_state(state):
|
44 |
+
|
45 |
+
if "messages" not in state:
|
46 |
+
state["messages"] = []
|
47 |
+
# -------------------------------
|
48 |
+
if "planner_model" not in state:
|
49 |
+
state["planner_model"] = "gpt-4o" # default
|
50 |
+
if "actor_model" not in state:
|
51 |
+
state["actor_model"] = "ShowUI" # default
|
52 |
+
|
53 |
+
if "planner_provider" not in state:
|
54 |
+
state["planner_provider"] = "openai" # default
|
55 |
+
if "actor_provider" not in state:
|
56 |
+
state["actor_provider"] = "local" # default
|
57 |
+
|
58 |
+
# Fetch API keys from environment variables
|
59 |
+
if "openai_api_key" not in state:
|
60 |
+
state["openai_api_key"] = os.getenv("OPENAI_API_KEY", "")
|
61 |
+
if "anthropic_api_key" not in state:
|
62 |
+
state["anthropic_api_key"] = os.getenv("ANTHROPIC_API_KEY", "")
|
63 |
+
if "qwen_api_key" not in state:
|
64 |
+
state["qwen_api_key"] = os.getenv("QWEN_API_KEY", "")
|
65 |
+
if "ui_tars_url" not in state:
|
66 |
+
state["ui_tars_url"] = ""
|
67 |
+
|
68 |
+
# Set the initial api_key based on the provider
|
69 |
+
if "planner_api_key" not in state:
|
70 |
+
if state["planner_provider"] == "openai":
|
71 |
+
state["planner_api_key"] = state["openai_api_key"]
|
72 |
+
elif state["planner_provider"] == "anthropic":
|
73 |
+
state["planner_api_key"] = state["anthropic_api_key"]
|
74 |
+
elif state["planner_provider"] == "qwen":
|
75 |
+
state["planner_api_key"] = state["qwen_api_key"]
|
76 |
+
else:
|
77 |
+
state["planner_api_key"] = ""
|
78 |
+
|
79 |
+
logger.info(f"loaded initial api_key for {state['planner_provider']}: {state['planner_api_key']}")
|
80 |
+
|
81 |
+
if not state["planner_api_key"]:
|
82 |
+
logger.warning("Planner API key not found. Please set it in the environment or paste in textbox.")
|
83 |
+
|
84 |
+
|
85 |
+
if "selected_screen" not in state:
|
86 |
+
state['selected_screen'] = SELECTED_SCREEN_INDEX if SCREEN_NAMES else 0
|
87 |
+
|
88 |
+
if "auth_validated" not in state:
|
89 |
+
state["auth_validated"] = False
|
90 |
+
if "responses" not in state:
|
91 |
+
state["responses"] = {}
|
92 |
+
if "tools" not in state:
|
93 |
+
state["tools"] = {}
|
94 |
+
if "only_n_most_recent_images" not in state:
|
95 |
+
state["only_n_most_recent_images"] = 10 # 10
|
96 |
+
if "custom_system_prompt" not in state:
|
97 |
+
state["custom_system_prompt"] = ""
|
98 |
+
# remove if want to use default system prompt
|
99 |
+
device_os_name = "Windows" if platform.system() == "Windows" else "Mac" if platform.system() == "Darwin" else "Linux"
|
100 |
+
state["custom_system_prompt"] += f"\n\nNOTE: you are operating a {device_os_name} machine"
|
101 |
+
if "hide_images" not in state:
|
102 |
+
state["hide_images"] = False
|
103 |
+
if 'chatbot_messages' not in state:
|
104 |
+
state['chatbot_messages'] = []
|
105 |
+
|
106 |
+
if "showui_config" not in state:
|
107 |
+
state["showui_config"] = "Default"
|
108 |
+
if "max_pixels" not in state:
|
109 |
+
state["max_pixels"] = 1344
|
110 |
+
if "awq_4bit" not in state:
|
111 |
+
state["awq_4bit"] = False
|
112 |
+
|
113 |
+
|
114 |
+
async def main(state):
|
115 |
+
"""Render loop for Gradio"""
|
116 |
+
setup_state(state)
|
117 |
+
return "Setup completed"
|
118 |
+
|
119 |
+
|
120 |
+
def validate_auth(provider: APIProvider, api_key: str | None):
|
121 |
+
if provider == APIProvider.ANTHROPIC:
|
122 |
+
if not api_key:
|
123 |
+
return "Enter your Anthropic API key to continue."
|
124 |
+
if provider == APIProvider.BEDROCK:
|
125 |
+
import boto3
|
126 |
+
|
127 |
+
if not boto3.Session().get_credentials():
|
128 |
+
return "You must have AWS credentials set up to use the Bedrock API."
|
129 |
+
if provider == APIProvider.VERTEX:
|
130 |
+
import google.auth
|
131 |
+
from google.auth.exceptions import DefaultCredentialsError
|
132 |
+
|
133 |
+
if not os.environ.get("CLOUD_ML_REGION"):
|
134 |
+
return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
|
135 |
+
try:
|
136 |
+
google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
137 |
+
except DefaultCredentialsError:
|
138 |
+
return "Your google cloud credentials are not set up correctly."
|
139 |
+
|
140 |
+
|
141 |
+
def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict):
|
142 |
+
response_id = datetime.now().isoformat()
|
143 |
+
response_state[response_id] = response
|
144 |
+
|
145 |
+
|
146 |
+
def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict):
|
147 |
+
tool_state[tool_id] = tool_output
|
148 |
+
|
149 |
+
|
150 |
+
def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="bot"):
|
151 |
+
|
152 |
+
def _render_message(message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, hide_images=False):
|
153 |
+
|
154 |
+
logger.info(f"_render_message: {str(message)[:100]}")
|
155 |
+
|
156 |
+
if isinstance(message, str):
|
157 |
+
return message
|
158 |
+
|
159 |
+
is_tool_result = not isinstance(message, str) and (
|
160 |
+
isinstance(message, ToolResult)
|
161 |
+
or message.__class__.__name__ == "ToolResult"
|
162 |
+
or message.__class__.__name__ == "CLIResult"
|
163 |
+
)
|
164 |
+
if not message or (
|
165 |
+
is_tool_result
|
166 |
+
and hide_images
|
167 |
+
and not hasattr(message, "error")
|
168 |
+
and not hasattr(message, "output")
|
169 |
+
): # return None if hide_images is True
|
170 |
+
return
|
171 |
+
# render tool result
|
172 |
+
if is_tool_result:
|
173 |
+
message = cast(ToolResult, message)
|
174 |
+
if message.output:
|
175 |
+
return message.output
|
176 |
+
if message.error:
|
177 |
+
return f"Error: {message.error}"
|
178 |
+
if message.base64_image and not hide_images:
|
179 |
+
# somehow can't display via gr.Image
|
180 |
+
# image_data = base64.b64decode(message.base64_image)
|
181 |
+
# return gr.Image(value=Image.open(io.BytesIO(image_data)))
|
182 |
+
return f'<img src="data:image/png;base64,{message.base64_image}">'
|
183 |
+
|
184 |
+
elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
|
185 |
+
return message.text
|
186 |
+
elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
|
187 |
+
return f"Tool Use: {message.name}\nInput: {message.input}"
|
188 |
+
else:
|
189 |
+
return message
|
190 |
+
|
191 |
+
|
192 |
+
# processing Anthropic messages
|
193 |
+
message = _render_message(message, hide_images)
|
194 |
+
|
195 |
+
if sender == "bot":
|
196 |
+
chatbot_state.append((None, message))
|
197 |
+
else:
|
198 |
+
chatbot_state.append((message, None))
|
199 |
+
|
200 |
+
# Create a concise version of the chatbot state for logging
|
201 |
+
concise_state = [(truncate_string(user_msg), truncate_string(bot_msg)) for user_msg, bot_msg in chatbot_state]
|
202 |
+
logger.info(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
|
203 |
+
|
204 |
+
|
205 |
+
def process_input(user_input, state):
|
206 |
+
|
207 |
+
setup_state(state)
|
208 |
+
|
209 |
+
# Append the user message to state["messages"]
|
210 |
+
state["messages"].append(
|
211 |
+
{
|
212 |
+
"role": "user",
|
213 |
+
"content": [TextBlock(type="text", text=user_input)],
|
214 |
+
}
|
215 |
+
)
|
216 |
+
|
217 |
+
# Append the user's message to chatbot_messages with None for the assistant's reply
|
218 |
+
state['chatbot_messages'].append((user_input, None))
|
219 |
+
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
|
220 |
+
|
221 |
+
# Run sampling_loop_sync with the chatbot_output_callback
|
222 |
+
for loop_msg in sampling_loop_sync(
|
223 |
+
system_prompt_suffix=state["custom_system_prompt"],
|
224 |
+
planner_model=state["planner_model"],
|
225 |
+
planner_provider=state["planner_provider"],
|
226 |
+
actor_model=state["actor_model"],
|
227 |
+
actor_provider=state["actor_provider"],
|
228 |
+
messages=state["messages"],
|
229 |
+
output_callback=partial(chatbot_output_callback, chatbot_state=state['chatbot_messages'], hide_images=state["hide_images"]),
|
230 |
+
tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
|
231 |
+
api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
|
232 |
+
api_key=state["planner_api_key"],
|
233 |
+
only_n_most_recent_images=state["only_n_most_recent_images"],
|
234 |
+
selected_screen=state['selected_screen'],
|
235 |
+
showui_max_pixels=state['max_pixels'],
|
236 |
+
showui_awq_4bit=state['awq_4bit']
|
237 |
+
):
|
238 |
+
if loop_msg is None:
|
239 |
+
yield state['chatbot_messages']
|
240 |
+
logger.info("End of task. Close the loop.")
|
241 |
+
break
|
242 |
+
|
243 |
+
|
244 |
+
yield state['chatbot_messages'] # Yield the updated chatbot_messages to update the chatbot UI
|
245 |
+
|
246 |
+
|
247 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
248 |
+
|
249 |
+
state = gr.State({}) # Use Gradio's state management
|
250 |
+
setup_state(state.value) # Initialize the state
|
251 |
+
|
252 |
+
# Retrieve screen details
|
253 |
+
gr.Markdown("# Computer Use OOTB")
|
254 |
+
|
255 |
+
if not os.getenv("HIDE_WARNING", False):
|
256 |
+
gr.Markdown(WARNING_TEXT)
|
257 |
+
|
258 |
+
with gr.Accordion("Settings", open=True):
|
259 |
+
with gr.Row():
|
260 |
+
with gr.Column():
|
261 |
+
# --------------------------
|
262 |
+
# Planner
|
263 |
+
planner_model = gr.Dropdown(
|
264 |
+
label="Planner Model",
|
265 |
+
choices=["gpt-4o",
|
266 |
+
"gpt-4o-mini",
|
267 |
+
"qwen2-vl-max",
|
268 |
+
"qwen2-vl-2b (local)",
|
269 |
+
"qwen2-vl-7b (local)",
|
270 |
+
"qwen2-vl-2b (ssh)",
|
271 |
+
"qwen2-vl-7b (ssh)",
|
272 |
+
"qwen2.5-vl-7b (ssh)",
|
273 |
+
"claude-3-5-sonnet-20241022"],
|
274 |
+
value="gpt-4o",
|
275 |
+
interactive=True,
|
276 |
+
)
|
277 |
+
with gr.Column():
|
278 |
+
planner_api_provider = gr.Dropdown(
|
279 |
+
label="API Provider",
|
280 |
+
choices=[option.value for option in APIProvider],
|
281 |
+
value="openai",
|
282 |
+
interactive=False,
|
283 |
+
)
|
284 |
+
with gr.Column():
|
285 |
+
planner_api_key = gr.Textbox(
|
286 |
+
label="Planner API Key",
|
287 |
+
type="password",
|
288 |
+
value=state.value.get("planner_api_key", ""),
|
289 |
+
placeholder="Paste your planner model API key",
|
290 |
+
interactive=True,
|
291 |
+
)
|
292 |
+
|
293 |
+
with gr.Column():
|
294 |
+
actor_model = gr.Dropdown(
|
295 |
+
label="Actor Model",
|
296 |
+
choices=["ShowUI", "UI-TARS"],
|
297 |
+
value="ShowUI",
|
298 |
+
interactive=True,
|
299 |
+
)
|
300 |
+
|
301 |
+
with gr.Column():
|
302 |
+
custom_prompt = gr.Textbox(
|
303 |
+
label="System Prompt Suffix",
|
304 |
+
value="",
|
305 |
+
interactive=True,
|
306 |
+
)
|
307 |
+
with gr.Column():
|
308 |
+
screen_options, primary_index = get_screen_details()
|
309 |
+
SCREEN_NAMES = screen_options
|
310 |
+
SELECTED_SCREEN_INDEX = primary_index
|
311 |
+
screen_selector = gr.Dropdown(
|
312 |
+
label="Select Screen",
|
313 |
+
choices=screen_options,
|
314 |
+
value=screen_options[primary_index] if screen_options else None,
|
315 |
+
interactive=True,
|
316 |
+
)
|
317 |
+
with gr.Column():
|
318 |
+
only_n_images = gr.Slider(
|
319 |
+
label="N most recent screenshots",
|
320 |
+
minimum=0,
|
321 |
+
maximum=10,
|
322 |
+
step=1,
|
323 |
+
value=2,
|
324 |
+
interactive=True,
|
325 |
+
)
|
326 |
+
|
327 |
+
with gr.Accordion("ShowUI Advanced Settings", open=False):
|
328 |
+
|
329 |
+
gr.Markdown("""
|
330 |
+
**Note:** Adjust these settings to fine-tune the resource (**memory** and **infer time**) and performance trade-offs of ShowUI. \\
|
331 |
+
Quantization model requires additional download. Please refer to [Computer Use OOTB - #ShowUI Advanced Settings guide](https://github.com/showlab/computer_use_ootb?tab=readme-ov-file#showui-advanced-settings) for preparation for this feature.
|
332 |
+
""")
|
333 |
+
|
334 |
+
# New configuration for ShowUI
|
335 |
+
with gr.Row():
|
336 |
+
with gr.Column():
|
337 |
+
showui_config = gr.Dropdown(
|
338 |
+
label="ShowUI Preset Configuration",
|
339 |
+
choices=["Default (Maximum)", "Medium", "Minimal", "Custom"],
|
340 |
+
value="Default (Maximum)",
|
341 |
+
interactive=True,
|
342 |
+
)
|
343 |
+
with gr.Column():
|
344 |
+
max_pixels = gr.Slider(
|
345 |
+
label="Max Visual Tokens",
|
346 |
+
minimum=720,
|
347 |
+
maximum=1344,
|
348 |
+
step=16,
|
349 |
+
value=1344,
|
350 |
+
interactive=False,
|
351 |
+
)
|
352 |
+
with gr.Column():
|
353 |
+
awq_4bit = gr.Checkbox(
|
354 |
+
label="Enable AWQ-4bit Model",
|
355 |
+
value=False,
|
356 |
+
interactive=False
|
357 |
+
)
|
358 |
+
|
359 |
+
# Define the merged dictionary with task mappings
|
360 |
+
merged_dict = json.load(open("assets/examples/ootb_examples.json", "r"))
|
361 |
+
|
362 |
+
def update_only_n_images(only_n_images_value, state):
|
363 |
+
state["only_n_most_recent_images"] = only_n_images_value
|
364 |
+
|
365 |
+
# Callback to update the second dropdown based on the first selection
|
366 |
+
def update_second_menu(selected_category):
|
367 |
+
return gr.update(choices=list(merged_dict.get(selected_category, {}).keys()))
|
368 |
+
|
369 |
+
# Callback to update the third dropdown based on the second selection
|
370 |
+
def update_third_menu(selected_category, selected_option):
|
371 |
+
return gr.update(choices=list(merged_dict.get(selected_category, {}).get(selected_option, {}).keys()))
|
372 |
+
|
373 |
+
# Callback to update the textbox based on the third selection
|
374 |
+
def update_textbox(selected_category, selected_option, selected_task):
|
375 |
+
task_data = merged_dict.get(selected_category, {}).get(selected_option, {}).get(selected_task, {})
|
376 |
+
prompt = task_data.get("prompt", "")
|
377 |
+
preview_image = task_data.get("initial_state", "")
|
378 |
+
task_hint = "Task Hint: " + task_data.get("hint", "")
|
379 |
+
return prompt, preview_image, task_hint
|
380 |
+
|
381 |
+
# Function to update the global variable when the dropdown changes
|
382 |
+
def update_selected_screen(selected_screen_name, state):
|
383 |
+
global SCREEN_NAMES
|
384 |
+
global SELECTED_SCREEN_INDEX
|
385 |
+
SELECTED_SCREEN_INDEX = SCREEN_NAMES.index(selected_screen_name)
|
386 |
+
logger.info(f"Selected screen updated to: {SELECTED_SCREEN_INDEX}")
|
387 |
+
state['selected_screen'] = SELECTED_SCREEN_INDEX
|
388 |
+
|
389 |
+
|
390 |
+
def update_planner_model(model_selection, state):
|
391 |
+
state["model"] = model_selection
|
392 |
+
# Update planner_model
|
393 |
+
state["planner_model"] = model_selection
|
394 |
+
logger.info(f"Model updated to: {state['planner_model']}")
|
395 |
+
|
396 |
+
if model_selection == "qwen2-vl-max":
|
397 |
+
provider_choices = ["qwen"]
|
398 |
+
provider_value = "qwen"
|
399 |
+
provider_interactive = False
|
400 |
+
api_key_interactive = True
|
401 |
+
api_key_placeholder = "qwen API key"
|
402 |
+
actor_model_choices = ["ShowUI", "UI-TARS"]
|
403 |
+
actor_model_value = "ShowUI"
|
404 |
+
actor_model_interactive = True
|
405 |
+
api_key_type = "password" # Display API key in password form
|
406 |
+
|
407 |
+
elif model_selection == "qwen2-vl-2b (local)" or model_selection == "qwen2-vl-7b (local)":
|
408 |
+
# Set provider to "openai", make it unchangeable
|
409 |
+
provider_choices = ["local"]
|
410 |
+
provider_value = "local"
|
411 |
+
provider_interactive = False
|
412 |
+
api_key_interactive = False
|
413 |
+
api_key_placeholder = "not required"
|
414 |
+
actor_model_choices = ["ShowUI", "UI-TARS"]
|
415 |
+
actor_model_value = "ShowUI"
|
416 |
+
actor_model_interactive = True
|
417 |
+
api_key_type = "password" # Maintain consistency
|
418 |
+
|
419 |
+
elif "ssh" in model_selection:
|
420 |
+
provider_choices = ["ssh"]
|
421 |
+
provider_value = "ssh"
|
422 |
+
provider_interactive = False
|
423 |
+
api_key_interactive = True
|
424 |
+
api_key_placeholder = "ssh host and port (e.g. localhost:8000)"
|
425 |
+
actor_model_choices = ["ShowUI", "UI-TARS"]
|
426 |
+
actor_model_value = "ShowUI"
|
427 |
+
actor_model_interactive = True
|
428 |
+
api_key_type = "text" # Display SSH connection info in plain text
|
429 |
+
# If SSH connection info already exists, keep it
|
430 |
+
if "planner_api_key" in state and state["planner_api_key"]:
|
431 |
+
state["api_key"] = state["planner_api_key"]
|
432 |
+
else:
|
433 |
+
state["api_key"] = ""
|
434 |
+
|
435 |
+
elif model_selection == "gpt-4o" or model_selection == "gpt-4o-mini":
|
436 |
+
# Set provider to "openai", make it unchangeable
|
437 |
+
provider_choices = ["openai"]
|
438 |
+
provider_value = "openai"
|
439 |
+
provider_interactive = False
|
440 |
+
api_key_interactive = True
|
441 |
+
api_key_type = "password" # Display API key in password form
|
442 |
+
|
443 |
+
api_key_placeholder = "openai API key"
|
444 |
+
actor_model_choices = ["ShowUI", "UI-TARS"]
|
445 |
+
actor_model_value = "ShowUI"
|
446 |
+
actor_model_interactive = True
|
447 |
+
|
448 |
+
elif model_selection == "claude-3-5-sonnet-20241022":
|
449 |
+
# Provider can be any of the current choices except 'openai'
|
450 |
+
provider_choices = [option.value for option in APIProvider if option.value != "openai"]
|
451 |
+
provider_value = "anthropic" # Set default to 'anthropic'
|
452 |
+
provider_interactive = True
|
453 |
+
api_key_interactive = True
|
454 |
+
api_key_placeholder = "claude API key"
|
455 |
+
actor_model_choices = ["claude-3-5-sonnet-20241022"]
|
456 |
+
actor_model_value = "claude-3-5-sonnet-20241022"
|
457 |
+
actor_model_interactive = False
|
458 |
+
api_key_type = "password" # Display API key in password form
|
459 |
+
|
460 |
+
else:
|
461 |
+
raise ValueError(f"Model {model_selection} not supported")
|
462 |
+
|
463 |
+
# Update the provider in state
|
464 |
+
state["planner_api_provider"] = provider_value
|
465 |
+
|
466 |
+
# Update api_key in state based on the provider
|
467 |
+
if provider_value == "openai":
|
468 |
+
state["api_key"] = state.get("openai_api_key", "")
|
469 |
+
elif provider_value == "anthropic":
|
470 |
+
state["api_key"] = state.get("anthropic_api_key", "")
|
471 |
+
elif provider_value == "qwen":
|
472 |
+
state["api_key"] = state.get("qwen_api_key", "")
|
473 |
+
elif provider_value == "local":
|
474 |
+
state["api_key"] = ""
|
475 |
+
# SSH的情况已经在上面处理过了,这里不需要重复处理
|
476 |
+
|
477 |
+
provider_update = gr.update(
|
478 |
+
choices=provider_choices,
|
479 |
+
value=provider_value,
|
480 |
+
interactive=provider_interactive
|
481 |
+
)
|
482 |
+
|
483 |
+
# Update the API Key textbox
|
484 |
+
api_key_update = gr.update(
|
485 |
+
placeholder=api_key_placeholder,
|
486 |
+
value=state["api_key"],
|
487 |
+
interactive=api_key_interactive,
|
488 |
+
type=api_key_type # 添加 type 参数的更新
|
489 |
+
)
|
490 |
+
|
491 |
+
actor_model_update = gr.update(
|
492 |
+
choices=actor_model_choices,
|
493 |
+
value=actor_model_value,
|
494 |
+
interactive=actor_model_interactive
|
495 |
+
)
|
496 |
+
|
497 |
+
logger.info(f"Updated state: model={state['planner_model']}, provider={state['planner_api_provider']}, api_key={state['api_key']}")
|
498 |
+
return provider_update, api_key_update, actor_model_update
|
499 |
+
|
500 |
+
def update_actor_model(actor_model_selection, state):
|
501 |
+
state["actor_model"] = actor_model_selection
|
502 |
+
logger.info(f"Actor model updated to: {state['actor_model']}")
|
503 |
+
|
504 |
+
def update_api_key_placeholder(provider_value, model_selection):
|
505 |
+
if model_selection == "claude-3-5-sonnet-20241022":
|
506 |
+
|
507 |
+
if provider_value == "anthropic":
|
508 |
+
return gr.update(placeholder="anthropic API key")
|
509 |
+
elif provider_value == "bedrock":
|
510 |
+
return gr.update(placeholder="bedrock API key")
|
511 |
+
elif provider_value == "vertex":
|
512 |
+
return gr.update(placeholder="vertex API key")
|
513 |
+
else:
|
514 |
+
return gr.update(placeholder="")
|
515 |
+
elif model_selection == "gpt-4o + ShowUI":
|
516 |
+
return gr.update(placeholder="openai API key")
|
517 |
+
else:
|
518 |
+
return gr.update(placeholder="")
|
519 |
+
|
520 |
+
def update_system_prompt_suffix(system_prompt_suffix, state):
|
521 |
+
state["custom_system_prompt"] = system_prompt_suffix
|
522 |
+
|
523 |
+
# When showui_config changes, we set the max_pixels and awq_4bit accordingly.
|
524 |
+
def handle_showui_config_change(showui_config_val, state):
|
525 |
+
if showui_config_val == "Default (Maximum)":
|
526 |
+
state["max_pixels"] = 1344
|
527 |
+
state["awq_4bit"] = False
|
528 |
+
return (
|
529 |
+
gr.update(value=1344, interactive=False),
|
530 |
+
gr.update(value=False, interactive=False)
|
531 |
+
)
|
532 |
+
elif showui_config_val == "Medium":
|
533 |
+
state["max_pixels"] = 1024
|
534 |
+
state["awq_4bit"] = False
|
535 |
+
return (
|
536 |
+
gr.update(value=1024, interactive=False),
|
537 |
+
gr.update(value=False, interactive=False)
|
538 |
+
)
|
539 |
+
elif showui_config_val == "Minimal":
|
540 |
+
state["max_pixels"] = 1024
|
541 |
+
state["awq_4bit"] = True
|
542 |
+
return (
|
543 |
+
gr.update(value=1024, interactive=False),
|
544 |
+
gr.update(value=True, interactive=False)
|
545 |
+
)
|
546 |
+
elif showui_config_val == "Custom":
|
547 |
+
# Do not overwrite the current user values, just make them interactive
|
548 |
+
return (
|
549 |
+
gr.update(interactive=True),
|
550 |
+
gr.update(interactive=True)
|
551 |
+
)
|
552 |
+
|
553 |
+
def update_api_key(api_key_value, state):
|
554 |
+
"""Handle API key updates"""
|
555 |
+
state["planner_api_key"] = api_key_value
|
556 |
+
if state["planner_provider"] == "ssh":
|
557 |
+
state["api_key"] = api_key_value
|
558 |
+
logger.info(f"API key updated: provider={state['planner_provider']}, api_key={state['api_key']}")
|
559 |
+
|
560 |
+
with gr.Accordion("Quick Start Prompt", open=False): # open=False 表示默认收
|
561 |
+
# Initialize Gradio interface with the dropdowns
|
562 |
+
with gr.Row():
|
563 |
+
# Set initial values
|
564 |
+
initial_category = "Game Play"
|
565 |
+
initial_second_options = list(merged_dict[initial_category].keys())
|
566 |
+
initial_third_options = list(merged_dict[initial_category][initial_second_options[0]].keys())
|
567 |
+
initial_text_value = merged_dict[initial_category][initial_second_options[0]][initial_third_options[0]]
|
568 |
+
|
569 |
+
with gr.Column(scale=2):
|
570 |
+
# First dropdown for Task Category
|
571 |
+
first_menu = gr.Dropdown(
|
572 |
+
choices=list(merged_dict.keys()), label="Task Category", interactive=True, value=initial_category
|
573 |
+
)
|
574 |
+
|
575 |
+
# Second dropdown for Software
|
576 |
+
second_menu = gr.Dropdown(
|
577 |
+
choices=initial_second_options, label="Software", interactive=True, value=initial_second_options[0]
|
578 |
+
)
|
579 |
+
|
580 |
+
# Third dropdown for Task
|
581 |
+
third_menu = gr.Dropdown(
|
582 |
+
choices=initial_third_options, label="Task", interactive=True, value=initial_third_options[0]
|
583 |
+
# choices=["Please select a task"]+initial_third_options, label="Task", interactive=True, value="Please select a task"
|
584 |
+
)
|
585 |
+
|
586 |
+
with gr.Column(scale=1):
|
587 |
+
initial_image_value = "./assets/examples/init_states/honkai_star_rail_showui.png" # default image path
|
588 |
+
image_preview = gr.Image(value=initial_image_value, label="Reference Initial State", height=260-(318.75-280))
|
589 |
+
hintbox = gr.Markdown("Task Hint: Selected options will appear here.")
|
590 |
+
|
591 |
+
# Textbox for displaying the mapped value
|
592 |
+
# textbox = gr.Textbox(value=initial_text_value, label="Action")
|
593 |
+
|
594 |
+
# api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key)
|
595 |
+
|
596 |
+
with gr.Row():
|
597 |
+
# submit_button = gr.Button("Submit") # Add submit button
|
598 |
+
with gr.Column(scale=8):
|
599 |
+
chat_input = gr.Textbox(show_label=False, placeholder="Type a message to send to Computer Use OOTB...", container=False)
|
600 |
+
with gr.Column(scale=1, min_width=50):
|
601 |
+
submit_button = gr.Button(value="Send", variant="primary")
|
602 |
+
|
603 |
+
chatbot = gr.Chatbot(label="Chatbot History", type="tuples", autoscroll=True, height=580)
|
604 |
+
|
605 |
+
planner_model.change(fn=update_planner_model, inputs=[planner_model, state], outputs=[planner_api_provider, planner_api_key, actor_model])
|
606 |
+
planner_api_provider.change(fn=update_api_key_placeholder, inputs=[planner_api_provider, planner_model], outputs=planner_api_key)
|
607 |
+
actor_model.change(fn=update_actor_model, inputs=[actor_model, state], outputs=None)
|
608 |
+
|
609 |
+
screen_selector.change(fn=update_selected_screen, inputs=[screen_selector, state], outputs=None)
|
610 |
+
only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
|
611 |
+
|
612 |
+
# When showui_config changes, we update max_pixels and awq_4bit automatically.
|
613 |
+
showui_config.change(fn=handle_showui_config_change,
|
614 |
+
inputs=[showui_config, state],
|
615 |
+
outputs=[max_pixels, awq_4bit])
|
616 |
+
|
617 |
+
# Link callbacks to update dropdowns based on selections
|
618 |
+
first_menu.change(fn=update_second_menu, inputs=first_menu, outputs=second_menu)
|
619 |
+
second_menu.change(fn=update_third_menu, inputs=[first_menu, second_menu], outputs=third_menu)
|
620 |
+
third_menu.change(fn=update_textbox, inputs=[first_menu, second_menu, third_menu], outputs=[chat_input, image_preview, hintbox])
|
621 |
+
|
622 |
+
# chat_input.submit(process_input, [chat_input, state], chatbot)
|
623 |
+
submit_button.click(process_input, [chat_input, state], chatbot)
|
624 |
+
|
625 |
+
planner_api_key.change(
|
626 |
+
fn=update_api_key,
|
627 |
+
inputs=[planner_api_key, state],
|
628 |
+
outputs=None
|
629 |
+
)
|
630 |
+
|
631 |
+
demo.launch(share=True,
|
632 |
+
allowed_paths=["./"],
|
633 |
+
server_port=7888) # TODO: allowed_paths
|
assets/Teaser.gif
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/amazon.png
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/booking.png
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/honkai_star_rail.png
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/honkai_star_rail_showui.png
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/ign.png
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/powerpoint.png
ADDED
![]() |
Git LFS Details
|
assets/examples/init_states/powerpoint_homepage.png
ADDED
![]() |
Git LFS Details
|
assets/examples/ootb_examples.json
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Web Navigation": {
|
3 |
+
"Shopping": {
|
4 |
+
"Search Gift Card": {
|
5 |
+
"hint": "Search for 'You are Amazing' congrats gift card",
|
6 |
+
"prompt": "Search for 'You are Amazing' congrats gift card",
|
7 |
+
"initial_state": ".\\assets\\examples\\init_states\\amazon.png"
|
8 |
+
},
|
9 |
+
"Add Headphones": {
|
10 |
+
"hint": "Add a set of wireless headphones to your cart",
|
11 |
+
"prompt": "Add a set of wireless headphones to your cart",
|
12 |
+
"initial_state": ".\\assets\\examples\\init_states\\amazon.png"
|
13 |
+
}
|
14 |
+
},
|
15 |
+
"Accommodation": {
|
16 |
+
"Find Private Room": {
|
17 |
+
"hint": "Find a private room in New York",
|
18 |
+
"prompt": "Find a private room in New York",
|
19 |
+
"initial_state": ".\\assets\\examples\\init_states\\booking.png"
|
20 |
+
}
|
21 |
+
},
|
22 |
+
"Gaming": {
|
23 |
+
"Walk-through Guide": {
|
24 |
+
"hint": "Find a walk-through guide for the game 'Black Myth: Wukong'",
|
25 |
+
"prompt": "Find a walk-through guide for the game 'Black Myth: Wukong'",
|
26 |
+
"initial_state": ".\\assets\\examples\\init_states\\ign.png"
|
27 |
+
}
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"Productivity": {
|
31 |
+
"Presentations": {
|
32 |
+
"Create Presentation": {
|
33 |
+
"hint": "Create a new presentation and set the title to 'Hail Computer Use OOTB!'",
|
34 |
+
"prompt": "Create a new presentation and edit the title to 'Hail Computer Use OOTB!'",
|
35 |
+
"initial_state": ".\\assets\\examples\\init_states\\powerpoint_homepage.png"
|
36 |
+
},
|
37 |
+
"Duplicate First Slide": {
|
38 |
+
"hint": "Duplicate the first slide in PowerPoint",
|
39 |
+
"prompt": "Duplicate the first slide in PowerPoint",
|
40 |
+
"initial_state": ".\\assets\\examples\\init_states\\powerpoint.png"
|
41 |
+
},
|
42 |
+
"Insert Picture": {
|
43 |
+
"hint": "Insert a picture from my device into the current slide, selecting the first image in the photo browser",
|
44 |
+
"prompt": "Insert a picture from my device into the current slide",
|
45 |
+
"initial_state": ".\\assets\\examples\\init_states\\powerpoint.png"
|
46 |
+
},
|
47 |
+
"Apply Morph Transition": {
|
48 |
+
"hint": "Apply the Morph transition to all slides",
|
49 |
+
"prompt": "Apply the Morph transition to all slides",
|
50 |
+
"initial_state": ".\\assets\\examples\\init_states\\powerpoint.png"
|
51 |
+
}
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"Game Play": {
|
55 |
+
"Honkai: Star Rail": {
|
56 |
+
"Daily Task (ShowUI)": {
|
57 |
+
"hint": "Complete the daily task",
|
58 |
+
"prompt": "1. Escape on the keyboard to open the menu. 2. Click 'Interastral Guide'. 3. Then click 'calyx golden for exp' entry. 4. Then click on the 'Teleport of Buds of MEMORIES'. 5. Press the 'bottom plus + button, the one below'. 6. Then click Challenge 7. Then click Start Challenge. 8. Then click on exit when the battle is completed.",
|
59 |
+
"initial_state": ".\\assets\\examples\\init_states\\honkai_star_rail_showui.png"
|
60 |
+
},
|
61 |
+
"Daily Task (Claude 3.5 Computer Use)": {
|
62 |
+
"hint": "Complete the daily task",
|
63 |
+
"prompt": "You are currently playing Honkai: Star Rail, your objective is to finish a daily game task for me. Press escape on the keyboard to open the menu, then click interastral guide, then click 'calyx golden for exp' entry on the left side of the popped up game window. Only then click on the teleport button on the same line of the first entry named 'buds of MEMORIES' (you need to carefully check the name), then click 'plus +' button 5 times to increase attempts to 6, then click challenge, then click start challenge. Then click the auto-battle button at the right-up corner - carefully count from the right to the left, it should be the second icon, it is near the 'pause' icon, it looks like an 'infinite' symbol. Then click on exit when the battle is completed.",
|
64 |
+
"initial_state": ".\\assets\\examples\\init_states\\honkai_star_rail.png"
|
65 |
+
},
|
66 |
+
"Warp": {
|
67 |
+
"hint": "Perform a warp (gacha pull)",
|
68 |
+
"prompt": "You are currently playing Honkai: Star Rail, your objective is to perform a 10-warp pull for me. Press escape on the keyboard to open the menu, then click warp. It should open the warp page, and the first entry on the left side would be 'Words of Yore', this would be the destination pool. Then click on 'warp x10' to perform a 10-warp pull, then click at the blank space at the right-up corner to reveal the arrow at the right-up corner, then click on the arrow to skip the animation. Always click on the arrow to continue skipping the animation if there is an arrow at the right-up corner. Only when all animations are skipped by clicking on the arrows, the pull summary page will appear and there would be a cross there, click on the cross to finish the pull. Good luck!",
|
69 |
+
"initial_state": ".\\assets\\examples\\init_states\\honkai_star_rail.png"
|
70 |
+
}
|
71 |
+
}
|
72 |
+
}
|
73 |
+
}
|
assets/gradio_interface.png
ADDED
![]() |
Git LFS Details
|
assets/ootb_icon.png
ADDED
![]() |
assets/ootb_logo.png
ADDED
![]() |
assets/wechat_3.jpg
ADDED
![]() |
Git LFS Details
|
computer_use_demo/__init__.py
ADDED
File without changes
|
computer_use_demo/executor/anthropic_executor.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from typing import Any, Dict, cast
|
3 |
+
from collections.abc import Callable
|
4 |
+
from anthropic.types.beta import (
|
5 |
+
BetaContentBlock,
|
6 |
+
BetaContentBlockParam,
|
7 |
+
BetaImageBlockParam,
|
8 |
+
BetaMessage,
|
9 |
+
BetaMessageParam,
|
10 |
+
BetaTextBlockParam,
|
11 |
+
BetaToolResultBlockParam,
|
12 |
+
)
|
13 |
+
from anthropic.types import TextBlock
|
14 |
+
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
15 |
+
from ..tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
|
16 |
+
|
17 |
+
|
18 |
+
class AnthropicExecutor:
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
output_callback: Callable[[BetaContentBlockParam], None],
|
22 |
+
tool_output_callback: Callable[[Any, str], None],
|
23 |
+
selected_screen: int = 0
|
24 |
+
):
|
25 |
+
self.tool_collection = ToolCollection(
|
26 |
+
ComputerTool(selected_screen=selected_screen),
|
27 |
+
BashTool(),
|
28 |
+
EditTool(),
|
29 |
+
)
|
30 |
+
self.output_callback = output_callback
|
31 |
+
self.tool_output_callback = tool_output_callback
|
32 |
+
|
33 |
+
def __call__(self, response: BetaMessage, messages: list[BetaMessageParam]):
|
34 |
+
new_message = {
|
35 |
+
"role": "assistant",
|
36 |
+
"content": cast(list[BetaContentBlockParam], response.content),
|
37 |
+
}
|
38 |
+
if new_message not in messages:
|
39 |
+
messages.append(new_message)
|
40 |
+
else:
|
41 |
+
print("new_message already in messages, there are duplicates.")
|
42 |
+
|
43 |
+
tool_result_content: list[BetaToolResultBlockParam] = []
|
44 |
+
for content_block in cast(list[BetaContentBlock], response.content):
|
45 |
+
|
46 |
+
self.output_callback(content_block, sender="bot")
|
47 |
+
# Execute the tool
|
48 |
+
if content_block.type == "tool_use":
|
49 |
+
# Run the asynchronous tool execution in a synchronous context
|
50 |
+
result = asyncio.run(self.tool_collection.run(
|
51 |
+
name=content_block.name,
|
52 |
+
tool_input=cast(dict[str, Any], content_block.input),
|
53 |
+
))
|
54 |
+
|
55 |
+
self.output_callback(result, sender="bot")
|
56 |
+
|
57 |
+
tool_result_content.append(
|
58 |
+
_make_api_tool_result(result, content_block.id)
|
59 |
+
)
|
60 |
+
self.tool_output_callback(result, content_block.id)
|
61 |
+
|
62 |
+
# Craft messages based on the content_block
|
63 |
+
# Note: to display the messages in the gradio, you should organize the messages in the following way (user message, bot message)
|
64 |
+
|
65 |
+
display_messages = _message_display_callback(messages)
|
66 |
+
# display_messages = []
|
67 |
+
|
68 |
+
# Send the messages to the gradio
|
69 |
+
for user_msg, bot_msg in display_messages:
|
70 |
+
yield [user_msg, bot_msg], tool_result_content
|
71 |
+
|
72 |
+
if not tool_result_content:
|
73 |
+
return messages
|
74 |
+
|
75 |
+
return tool_result_content
|
76 |
+
|
77 |
+
def _message_display_callback(messages):
|
78 |
+
display_messages = []
|
79 |
+
for msg in messages:
|
80 |
+
try:
|
81 |
+
if isinstance(msg["content"][0], TextBlock):
|
82 |
+
display_messages.append((msg["content"][0].text, None)) # User message
|
83 |
+
elif isinstance(msg["content"][0], BetaTextBlock):
|
84 |
+
display_messages.append((None, msg["content"][0].text)) # Bot message
|
85 |
+
elif isinstance(msg["content"][0], BetaToolUseBlock):
|
86 |
+
display_messages.append((None, f"Tool Use: {msg['content'][0].name}\nInput: {msg['content'][0].input}")) # Bot message
|
87 |
+
elif isinstance(msg["content"][0], Dict) and msg["content"][0]["content"][-1]["type"] == "image":
|
88 |
+
display_messages.append((None, f'<img src="data:image/png;base64,{msg["content"][0]["content"][-1]["source"]["data"]}">')) # Bot message
|
89 |
+
else:
|
90 |
+
print(msg["content"][0])
|
91 |
+
except Exception as e:
|
92 |
+
print("error", e)
|
93 |
+
pass
|
94 |
+
return display_messages
|
95 |
+
|
96 |
+
def _make_api_tool_result(
|
97 |
+
result: ToolResult, tool_use_id: str
|
98 |
+
) -> BetaToolResultBlockParam:
|
99 |
+
"""Convert an agent ToolResult to an API ToolResultBlockParam."""
|
100 |
+
tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = []
|
101 |
+
is_error = False
|
102 |
+
if result.error:
|
103 |
+
is_error = True
|
104 |
+
tool_result_content = _maybe_prepend_system_tool_result(result, result.error)
|
105 |
+
else:
|
106 |
+
if result.output:
|
107 |
+
tool_result_content.append(
|
108 |
+
{
|
109 |
+
"type": "text",
|
110 |
+
"text": _maybe_prepend_system_tool_result(result, result.output),
|
111 |
+
}
|
112 |
+
)
|
113 |
+
if result.base64_image:
|
114 |
+
tool_result_content.append(
|
115 |
+
{
|
116 |
+
"type": "image",
|
117 |
+
"source": {
|
118 |
+
"type": "base64",
|
119 |
+
"media_type": "image/png",
|
120 |
+
"data": result.base64_image,
|
121 |
+
},
|
122 |
+
}
|
123 |
+
)
|
124 |
+
return {
|
125 |
+
"type": "tool_result",
|
126 |
+
"content": tool_result_content,
|
127 |
+
"tool_use_id": tool_use_id,
|
128 |
+
"is_error": is_error,
|
129 |
+
}
|
130 |
+
|
131 |
+
|
132 |
+
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
|
133 |
+
if result.system:
|
134 |
+
result_text = f"<system>{result.system}</system>\n{result_text}"
|
135 |
+
return result_text
|
computer_use_demo/executor/showui_executor.py
ADDED
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import asyncio
|
3 |
+
from typing import Any, Dict, cast, List, Union
|
4 |
+
from collections.abc import Callable
|
5 |
+
import uuid
|
6 |
+
from anthropic.types.beta import (
|
7 |
+
BetaContentBlock,
|
8 |
+
BetaContentBlockParam,
|
9 |
+
BetaImageBlockParam,
|
10 |
+
BetaMessage,
|
11 |
+
BetaMessageParam,
|
12 |
+
BetaTextBlockParam,
|
13 |
+
BetaToolResultBlockParam,
|
14 |
+
)
|
15 |
+
from anthropic.types import TextBlock
|
16 |
+
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
17 |
+
from computer_use_demo.tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
|
18 |
+
from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
|
19 |
+
|
20 |
+
|
21 |
+
class ShowUIExecutor:
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
output_callback: Callable[[BetaContentBlockParam], None],
|
25 |
+
tool_output_callback: Callable[[Any, str], None],
|
26 |
+
selected_screen: int = 0
|
27 |
+
):
|
28 |
+
self.output_callback = output_callback
|
29 |
+
self.tool_output_callback = tool_output_callback
|
30 |
+
self.selected_screen = selected_screen
|
31 |
+
self.screen_bbox = self._get_screen_resolution()
|
32 |
+
print("Screen BBox:", self.screen_bbox)
|
33 |
+
|
34 |
+
self.tool_collection = ToolCollection(
|
35 |
+
ComputerTool(selected_screen=selected_screen, is_scaling=False)
|
36 |
+
)
|
37 |
+
|
38 |
+
self.supported_action_type={
|
39 |
+
# "showui_action": "anthropic_tool_action"
|
40 |
+
"CLICK": 'key', # TBD
|
41 |
+
"INPUT": "key",
|
42 |
+
"ENTER": "key", # TBD
|
43 |
+
"ESC": "key",
|
44 |
+
"ESCAPE": "key",
|
45 |
+
"PRESS": "key",
|
46 |
+
}
|
47 |
+
|
48 |
+
def __call__(self, response: str, messages: list[BetaMessageParam]):
|
49 |
+
# response is expected to be :
|
50 |
+
# {'content': "{'action': 'CLICK', 'value': None, 'position': [0.83, 0.15]}, ...", 'role': 'assistant'},
|
51 |
+
|
52 |
+
action_dict = self._format_actor_output(response) # str -> dict
|
53 |
+
|
54 |
+
actions = action_dict["content"]
|
55 |
+
role = action_dict["role"]
|
56 |
+
|
57 |
+
# Parse the actions from showui
|
58 |
+
action_list = self._parse_showui_output(actions)
|
59 |
+
print("Parsed Action List:", action_list)
|
60 |
+
|
61 |
+
tool_result_content = None
|
62 |
+
|
63 |
+
if action_list is not None and len(action_list) > 0:
|
64 |
+
|
65 |
+
for action in action_list: # Execute the tool (adapting the code from anthropic_executor.py)
|
66 |
+
|
67 |
+
tool_result_content: list[BetaToolResultBlockParam] = []
|
68 |
+
|
69 |
+
self.output_callback(f"{colorful_text_showui}:\n{action}", sender="bot")
|
70 |
+
print("Converted Action:", action)
|
71 |
+
|
72 |
+
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
73 |
+
input={'action': action["action"], 'text': action["text"], 'coordinate': action["coordinate"]},
|
74 |
+
name='computer', type='tool_use')
|
75 |
+
|
76 |
+
# update messages
|
77 |
+
new_message = {
|
78 |
+
"role": "assistant",
|
79 |
+
"content": cast(list[BetaContentBlockParam], [sim_content_block]),
|
80 |
+
}
|
81 |
+
if new_message not in messages:
|
82 |
+
messages.append(new_message)
|
83 |
+
|
84 |
+
# Run the asynchronous tool execution in a synchronous context
|
85 |
+
result = self.tool_collection.sync_call(
|
86 |
+
name=sim_content_block.name,
|
87 |
+
tool_input=cast(dict[str, Any], sim_content_block.input),
|
88 |
+
)
|
89 |
+
|
90 |
+
tool_result_content.append(
|
91 |
+
_make_api_tool_result(result, sim_content_block.id)
|
92 |
+
)
|
93 |
+
# print(f"executor: tool_result_content: {tool_result_content}")
|
94 |
+
self.tool_output_callback(result, sim_content_block.id)
|
95 |
+
|
96 |
+
# Craft messages based on the content_block
|
97 |
+
# Note: to display the messages in the gradio, you should organize the messages in the following way (user message, bot message)
|
98 |
+
display_messages = _message_display_callback(messages)
|
99 |
+
# Send the messages to the gradio
|
100 |
+
for user_msg, bot_msg in display_messages:
|
101 |
+
yield [user_msg, bot_msg], tool_result_content
|
102 |
+
|
103 |
+
return tool_result_content
|
104 |
+
|
105 |
+
|
106 |
+
def _format_actor_output(self, action_output: str|dict) -> Dict[str, Any]:
|
107 |
+
if type(action_output) == dict:
|
108 |
+
return action_output
|
109 |
+
else:
|
110 |
+
try:
|
111 |
+
action_output.replace("'", "\"")
|
112 |
+
action_dict = ast.literal_eval(action_output)
|
113 |
+
return action_dict
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Error parsing action output: {e}")
|
116 |
+
return None
|
117 |
+
|
118 |
+
|
119 |
+
def _parse_showui_output(self, output_text: str) -> Union[List[Dict[str, Any]], None]:
|
120 |
+
try:
|
121 |
+
output_text = output_text.strip()
|
122 |
+
|
123 |
+
# process single dictionary
|
124 |
+
if output_text.startswith("{") and output_text.endswith("}"):
|
125 |
+
output_text = f"[{output_text}]"
|
126 |
+
|
127 |
+
# Validate if the output resembles a list of dictionaries
|
128 |
+
if not (output_text.startswith("[") and output_text.endswith("]")):
|
129 |
+
raise ValueError("Output does not look like a valid list or dictionary.")
|
130 |
+
|
131 |
+
print("Output Text:", output_text)
|
132 |
+
|
133 |
+
parsed_output = ast.literal_eval(output_text)
|
134 |
+
|
135 |
+
print("Parsed Output:", parsed_output)
|
136 |
+
|
137 |
+
if isinstance(parsed_output, dict):
|
138 |
+
parsed_output = [parsed_output]
|
139 |
+
elif not isinstance(parsed_output, list):
|
140 |
+
raise ValueError("Parsed output is neither a dictionary nor a list.")
|
141 |
+
|
142 |
+
if not all(isinstance(item, dict) for item in parsed_output):
|
143 |
+
raise ValueError("Not all items in the parsed output are dictionaries.")
|
144 |
+
|
145 |
+
# refine key: value pairs, mapping to the Anthropic's format
|
146 |
+
refined_output = []
|
147 |
+
|
148 |
+
for action_item in parsed_output:
|
149 |
+
|
150 |
+
print("Action Item:", action_item)
|
151 |
+
# sometime showui returns lower case action names
|
152 |
+
action_item["action"] = action_item["action"].upper()
|
153 |
+
|
154 |
+
if action_item["action"] not in self.supported_action_type:
|
155 |
+
raise ValueError(f"Action {action_item['action']} not supported. Check the output from ShowUI: {output_text}")
|
156 |
+
# continue
|
157 |
+
|
158 |
+
elif action_item["action"] == "CLICK": # 1. click -> mouse_move + left_click
|
159 |
+
x, y = action_item["position"]
|
160 |
+
action_item["position"] = (int(x * (self.screen_bbox[2] - self.screen_bbox[0])),
|
161 |
+
int(y * (self.screen_bbox[3] - self.screen_bbox[1])))
|
162 |
+
refined_output.append({"action": "mouse_move", "text": None, "coordinate": tuple(action_item["position"])})
|
163 |
+
refined_output.append({"action": "left_click", "text": None, "coordinate": None})
|
164 |
+
|
165 |
+
elif action_item["action"] == "INPUT": # 2. input -> type
|
166 |
+
refined_output.append({"action": "type", "text": action_item["value"], "coordinate": None})
|
167 |
+
|
168 |
+
elif action_item["action"] == "ENTER": # 3. enter -> key, enter
|
169 |
+
refined_output.append({"action": "key", "text": "Enter", "coordinate": None})
|
170 |
+
|
171 |
+
elif action_item["action"] == "ESC" or action_item["action"] == "ESCAPE": # 4. enter -> key, enter
|
172 |
+
refined_output.append({"action": "key", "text": "Escape", "coordinate": None})
|
173 |
+
|
174 |
+
elif action_item["action"] == "HOVER": # 5. hover -> mouse_move
|
175 |
+
x, y = action_item["position"]
|
176 |
+
action_item["position"] = (int(x * (self.screen_bbox[2] - self.screen_bbox[0])),
|
177 |
+
int(y * (self.screen_bbox[3] - self.screen_bbox[1])))
|
178 |
+
refined_output.append({"action": "mouse_move", "text": None, "coordinate": tuple(action_item["position"])})
|
179 |
+
|
180 |
+
elif action_item["action"] == "SCROLL": # 6. scroll -> key: pagedown
|
181 |
+
if action_item["value"] == "up":
|
182 |
+
refined_output.append({"action": "key", "text": "pageup", "coordinate": None})
|
183 |
+
elif action_item["value"] == "down":
|
184 |
+
refined_output.append({"action": "key", "text": "pagedown", "coordinate": None})
|
185 |
+
else:
|
186 |
+
raise ValueError(f"Scroll direction {action_item['value']} not supported.")
|
187 |
+
|
188 |
+
elif action_item["action"] == "PRESS": # 7. press
|
189 |
+
x, y = action_item["position"]
|
190 |
+
action_item["position"] = (int(x * (self.screen_bbox[2] - self.screen_bbox[0])),
|
191 |
+
int(y * (self.screen_bbox[3] - self.screen_bbox[1])))
|
192 |
+
refined_output.append({"action": "mouse_move", "text": None, "coordinate": tuple(action_item["position"])})
|
193 |
+
refined_output.append({"action": "left_press", "text": None, "coordinate": None})
|
194 |
+
|
195 |
+
return refined_output
|
196 |
+
|
197 |
+
except Exception as e:
|
198 |
+
print(f"Error parsing output: {e}")
|
199 |
+
return None
|
200 |
+
|
201 |
+
|
202 |
+
def _get_screen_resolution(self):
|
203 |
+
from screeninfo import get_monitors
|
204 |
+
import platform
|
205 |
+
if platform.system() == "Darwin":
|
206 |
+
import Quartz # uncomment this line if you are on macOS
|
207 |
+
import subprocess
|
208 |
+
|
209 |
+
# Detect platform
|
210 |
+
system = platform.system()
|
211 |
+
|
212 |
+
if system == "Windows":
|
213 |
+
# Windows: Use screeninfo to get monitor details
|
214 |
+
screens = get_monitors()
|
215 |
+
|
216 |
+
# Sort screens by x position to arrange from left to right
|
217 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
218 |
+
|
219 |
+
if self.selected_screen < 0 or self.selected_screen >= len(screens):
|
220 |
+
raise IndexError("Invalid screen index.")
|
221 |
+
|
222 |
+
screen = sorted_screens[self.selected_screen]
|
223 |
+
bbox = (screen.x, screen.y, screen.x + screen.width, screen.y + screen.height)
|
224 |
+
|
225 |
+
elif system == "Darwin": # macOS
|
226 |
+
# macOS: Use Quartz to get monitor details
|
227 |
+
max_displays = 32 # Maximum number of displays to handle
|
228 |
+
active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
|
229 |
+
|
230 |
+
# Get the display bounds (resolution) for each active display
|
231 |
+
screens = []
|
232 |
+
for display_id in active_displays:
|
233 |
+
bounds = Quartz.CGDisplayBounds(display_id)
|
234 |
+
screens.append({
|
235 |
+
'id': display_id,
|
236 |
+
'x': int(bounds.origin.x),
|
237 |
+
'y': int(bounds.origin.y),
|
238 |
+
'width': int(bounds.size.width),
|
239 |
+
'height': int(bounds.size.height),
|
240 |
+
'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
|
241 |
+
})
|
242 |
+
|
243 |
+
# Sort screens by x position to arrange from left to right
|
244 |
+
sorted_screens = sorted(screens, key=lambda s: s['x'])
|
245 |
+
|
246 |
+
if self.selected_screen < 0 or self.selected_screen >= len(screens):
|
247 |
+
raise IndexError("Invalid screen index.")
|
248 |
+
|
249 |
+
screen = sorted_screens[self.selected_screen]
|
250 |
+
bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
|
251 |
+
|
252 |
+
else: # Linux or other OS
|
253 |
+
cmd = "xrandr | grep ' primary' | awk '{print $4}'"
|
254 |
+
try:
|
255 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
256 |
+
resolution = output.strip().split()[0]
|
257 |
+
width, height = map(int, resolution.split('x'))
|
258 |
+
bbox = (0, 0, width, height) # Assuming single primary screen for simplicity
|
259 |
+
except subprocess.CalledProcessError:
|
260 |
+
raise RuntimeError("Failed to get screen resolution on Linux.")
|
261 |
+
|
262 |
+
return bbox
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
def _message_display_callback(messages):
|
267 |
+
display_messages = []
|
268 |
+
for msg in messages:
|
269 |
+
try:
|
270 |
+
if isinstance(msg["content"][0], TextBlock):
|
271 |
+
display_messages.append((msg["content"][0].text, None)) # User message
|
272 |
+
elif isinstance(msg["content"][0], BetaTextBlock):
|
273 |
+
display_messages.append((None, msg["content"][0].text)) # Bot message
|
274 |
+
elif isinstance(msg["content"][0], BetaToolUseBlock):
|
275 |
+
display_messages.append((None, f"Tool Use: {msg['content'][0].name}\nInput: {msg['content'][0].input}")) # Bot message
|
276 |
+
elif isinstance(msg["content"][0], Dict) and msg["content"][0]["content"][-1]["type"] == "image":
|
277 |
+
display_messages.append((None, f'<img src="data:image/png;base64,{msg["content"][0]["content"][-1]["source"]["data"]}">')) # Bot message
|
278 |
+
else:
|
279 |
+
pass
|
280 |
+
# print(msg["content"][0])
|
281 |
+
except Exception as e:
|
282 |
+
print("error", e)
|
283 |
+
pass
|
284 |
+
return display_messages
|
285 |
+
|
286 |
+
|
287 |
+
def _make_api_tool_result(
|
288 |
+
result: ToolResult, tool_use_id: str
|
289 |
+
) -> BetaToolResultBlockParam:
|
290 |
+
"""Convert an agent ToolResult to an API ToolResultBlockParam."""
|
291 |
+
tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = []
|
292 |
+
is_error = False
|
293 |
+
if result.error:
|
294 |
+
is_error = True
|
295 |
+
tool_result_content = _maybe_prepend_system_tool_result(result, result.error)
|
296 |
+
else:
|
297 |
+
if result.output:
|
298 |
+
tool_result_content.append(
|
299 |
+
{
|
300 |
+
"type": "text",
|
301 |
+
"text": _maybe_prepend_system_tool_result(result, result.output),
|
302 |
+
}
|
303 |
+
)
|
304 |
+
if result.base64_image:
|
305 |
+
tool_result_content.append(
|
306 |
+
{
|
307 |
+
"type": "image",
|
308 |
+
"source": {
|
309 |
+
"type": "base64",
|
310 |
+
"media_type": "image/png",
|
311 |
+
"data": result.base64_image,
|
312 |
+
},
|
313 |
+
}
|
314 |
+
)
|
315 |
+
return {
|
316 |
+
"type": "tool_result",
|
317 |
+
"content": tool_result_content,
|
318 |
+
"tool_use_id": tool_use_id,
|
319 |
+
"is_error": is_error,
|
320 |
+
}
|
321 |
+
|
322 |
+
|
323 |
+
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
|
324 |
+
if result.system:
|
325 |
+
result_text = f"<system>{result.system}</system>\n{result_text}"
|
326 |
+
return result_text
|
327 |
+
|
328 |
+
|
329 |
+
|
330 |
+
# Testing main function
|
331 |
+
if __name__ == "__main__":
|
332 |
+
def output_callback(content_block):
|
333 |
+
# print("Output Callback:", content_block)
|
334 |
+
pass
|
335 |
+
|
336 |
+
def tool_output_callback(result, action):
|
337 |
+
print("[showui_executor] Tool Output Callback:", result, action)
|
338 |
+
pass
|
339 |
+
|
340 |
+
# Instantiate the executor
|
341 |
+
executor = ShowUIExecutor(
|
342 |
+
output_callback=output_callback,
|
343 |
+
tool_output_callback=tool_output_callback,
|
344 |
+
selected_screen=0
|
345 |
+
)
|
346 |
+
|
347 |
+
# test inputs
|
348 |
+
response_content = "{'content': \"{'action': 'CLICK', 'value': None, 'position': [0.49, 0.18]}\", 'role': 'assistant'}"
|
349 |
+
# response_content = {'content': "{'action': 'CLICK', 'value': None, 'position': [0.49, 0.39]}", 'role': 'assistant'}
|
350 |
+
# response_content = "{'content': \"{'action': 'CLICK', 'value': None, 'position': [0.49, 0.42]}, {'action': 'INPUT', 'value': 'weather for New York city', 'position': [0.49, 0.42]}, {'action': 'ENTER', 'value': None, 'position': None}\", 'role': 'assistant'}"
|
351 |
+
|
352 |
+
# Initialize messages
|
353 |
+
messages = []
|
354 |
+
|
355 |
+
# Call the executor
|
356 |
+
print("Testing ShowUIExecutor with response content:", response_content)
|
357 |
+
for message, tool_result_content in executor(response_content, messages):
|
358 |
+
print("Message:", message)
|
359 |
+
print("Tool Result Content:", tool_result_content)
|
360 |
+
|
361 |
+
# Display final messages
|
362 |
+
print("\nFinal messages:")
|
363 |
+
for msg in messages:
|
364 |
+
print(msg)
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
[
|
369 |
+
{'role': 'user', 'content': ['open a new tab and go to amazon.com', 'tmp/outputs/screenshot_b4a1b7e60a5c47359bedbd8707573966.png']},
|
370 |
+
{'role': 'assistant', 'content': ["History Action: {'action': 'mouse_move', 'text': None, 'coordinate': (1216, 88)}"]},
|
371 |
+
{'role': 'assistant', 'content': ["History Action: {'action': 'left_click', 'text': None, 'coordinate': None}"]},
|
372 |
+
{'content': [
|
373 |
+
{'type': 'tool_result', 'content': [{'type': 'text', 'text': 'Moved mouse to (1216, 88)'}], 'tool_use_id': 'toolu_ae4f2886-366c-4789-9fa6-ec13461cef12', 'is_error': False},
|
374 |
+
{'type': 'tool_result', 'content': [{'type': 'text', 'text': 'Performed left_click'}], 'tool_use_id': 'toolu_a7377954-e1b7-4746-9757-b2eb4dcddc82', 'is_error': False}
|
375 |
+
], 'role': 'user'}
|
376 |
+
]
|
computer_use_demo/gui_agent/actor/showui_agent.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import ast
|
3 |
+
import base64
|
4 |
+
from io import BytesIO
|
5 |
+
from pathlib import Path
|
6 |
+
from uuid import uuid4
|
7 |
+
|
8 |
+
import pyautogui
|
9 |
+
import requests
|
10 |
+
import torch
|
11 |
+
from PIL import Image, ImageDraw
|
12 |
+
from qwen_vl_utils import process_vision_info
|
13 |
+
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
14 |
+
|
15 |
+
from computer_use_demo.gui_agent.llm_utils.oai import encode_image
|
16 |
+
from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
|
17 |
+
from computer_use_demo.tools.screen_capture import get_screenshot
|
18 |
+
|
19 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
20 |
+
|
21 |
+
|
22 |
+
class ShowUIActor:
|
23 |
+
_NAV_SYSTEM = """
|
24 |
+
You are an assistant trained to navigate the {_APP} screen.
|
25 |
+
Given a task instruction, a screen observation, and an action history sequence,
|
26 |
+
output the next action and wait for the next observation.
|
27 |
+
Here is the action space:
|
28 |
+
{_ACTION_SPACE}
|
29 |
+
"""
|
30 |
+
|
31 |
+
_NAV_FORMAT = """
|
32 |
+
Format the action as a dictionary with the following keys:
|
33 |
+
{'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}
|
34 |
+
|
35 |
+
If value or position is not applicable, set it as None.
|
36 |
+
Position might be [[x1,y1], [x2,y2]] if the action requires a start and end position.
|
37 |
+
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
|
38 |
+
"""
|
39 |
+
|
40 |
+
action_map = {
|
41 |
+
'desktop': """
|
42 |
+
1. CLICK: Click on an element, value is not applicable and the position [x,y] is required.
|
43 |
+
2. INPUT: Type a string into an element, value is a string to type and the position [x,y] is required.
|
44 |
+
3. HOVER: Hover on an element, value is not applicable and the position [x,y] is required.
|
45 |
+
4. ENTER: Enter operation, value and position are not applicable.
|
46 |
+
5. SCROLL: Scroll the screen, value is the direction to scroll and the position is not applicable.
|
47 |
+
6. ESC: ESCAPE operation, value and position are not applicable.
|
48 |
+
7. PRESS: Long click on an element, value is not applicable and the position [x,y] is required.
|
49 |
+
""",
|
50 |
+
'phone': """
|
51 |
+
1. INPUT: Type a string into an element, value is not applicable and the position [x,y] is required.
|
52 |
+
2. SWIPE: Swipe the screen, value is not applicable and the position [[x1,y1], [x2,y2]] is the start and end position of the swipe operation.
|
53 |
+
3. TAP: Tap on an element, value is not applicable and the position [x,y] is required.
|
54 |
+
4. ANSWER: Answer the question, value is the status (e.g., 'task complete') and the position is not applicable.
|
55 |
+
5. ENTER: Enter operation, value and position are not applicable.
|
56 |
+
"""
|
57 |
+
}
|
58 |
+
|
59 |
+
def __init__(self, model_path, output_callback, device=torch.device("cpu"), split='desktop', selected_screen=0,
|
60 |
+
max_pixels=1344, awq_4bit=False):
|
61 |
+
self.device = device
|
62 |
+
self.split = split
|
63 |
+
self.selected_screen = selected_screen
|
64 |
+
self.output_callback = output_callback
|
65 |
+
|
66 |
+
if not model_path or not os.path.exists(model_path) or not os.listdir(model_path):
|
67 |
+
if awq_4bit:
|
68 |
+
model_path = "showlab/ShowUI-2B-AWQ-4bit"
|
69 |
+
else:
|
70 |
+
model_path = "showlab/ShowUI-2B"
|
71 |
+
|
72 |
+
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
73 |
+
model_path,
|
74 |
+
torch_dtype=torch.bfloat16,
|
75 |
+
device_map="cpu"
|
76 |
+
).to(self.device)
|
77 |
+
self.model.eval()
|
78 |
+
|
79 |
+
self.min_pixels = 256 * 28 * 28
|
80 |
+
self.max_pixels = max_pixels * 28 * 28
|
81 |
+
# self.max_pixels = 1344 * 28 * 28
|
82 |
+
|
83 |
+
self.processor = AutoProcessor.from_pretrained(
|
84 |
+
"Qwen/Qwen2-VL-2B-Instruct",
|
85 |
+
# "./Qwen2-VL-2B-Instruct",
|
86 |
+
min_pixels=self.min_pixels,
|
87 |
+
max_pixels=self.max_pixels
|
88 |
+
)
|
89 |
+
self.system_prompt = self._NAV_SYSTEM.format(
|
90 |
+
_APP=split,
|
91 |
+
_ACTION_SPACE=self.action_map[split]
|
92 |
+
)
|
93 |
+
self.action_history = '' # Initialize action history
|
94 |
+
|
95 |
+
def __call__(self, messages):
|
96 |
+
|
97 |
+
task = messages
|
98 |
+
|
99 |
+
# screenshot
|
100 |
+
screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen, resize=True, target_width=1920, target_height=1080)
|
101 |
+
screenshot_path = str(screenshot_path)
|
102 |
+
image_base64 = encode_image(screenshot_path)
|
103 |
+
self.output_callback(f'Screenshot for {colorful_text_showui}:\n<img src="data:image/png;base64,{image_base64}">', sender="bot")
|
104 |
+
|
105 |
+
# Use system prompt, task, and action history to build the messages
|
106 |
+
messages_for_processor = [
|
107 |
+
{
|
108 |
+
"role": "user",
|
109 |
+
"content": [
|
110 |
+
{"type": "text", "text": self.system_prompt},
|
111 |
+
{"type": "image", "image": screenshot_path, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels},
|
112 |
+
{"type": "text", "text": f"Task: {task}"}
|
113 |
+
],
|
114 |
+
}
|
115 |
+
]
|
116 |
+
|
117 |
+
text = self.processor.apply_chat_template(
|
118 |
+
messages_for_processor, tokenize=False, add_generation_prompt=True,
|
119 |
+
)
|
120 |
+
image_inputs, video_inputs = process_vision_info(messages_for_processor)
|
121 |
+
inputs = self.processor(
|
122 |
+
text=[text],
|
123 |
+
images=image_inputs,
|
124 |
+
videos=video_inputs,
|
125 |
+
padding=True,
|
126 |
+
return_tensors="pt",
|
127 |
+
)
|
128 |
+
inputs = inputs.to(self.device)
|
129 |
+
|
130 |
+
with torch.no_grad():
|
131 |
+
generated_ids = self.model.generate(**inputs, max_new_tokens=128)
|
132 |
+
|
133 |
+
generated_ids_trimmed = [
|
134 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
135 |
+
]
|
136 |
+
output_text = self.processor.batch_decode(
|
137 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
138 |
+
)[0]
|
139 |
+
|
140 |
+
# Update action history
|
141 |
+
self.action_history += output_text + '\n'
|
142 |
+
|
143 |
+
# Return response in expected format
|
144 |
+
response = {'content': output_text, 'role': 'assistant'}
|
145 |
+
return response
|
146 |
+
|
147 |
+
|
148 |
+
def parse_showui_output(self, output_text):
|
149 |
+
try:
|
150 |
+
# Ensure the output is stripped of any extra spaces
|
151 |
+
output_text = output_text.strip()
|
152 |
+
|
153 |
+
# Wrap the input in brackets if it looks like a single dictionary
|
154 |
+
if output_text.startswith("{") and output_text.endswith("}"):
|
155 |
+
output_text = f"[{output_text}]"
|
156 |
+
|
157 |
+
# Validate if the output resembles a list of dictionaries
|
158 |
+
if not (output_text.startswith("[") and output_text.endswith("]")):
|
159 |
+
raise ValueError("Output does not look like a valid list or dictionary.")
|
160 |
+
|
161 |
+
# Parse the output using ast.literal_eval
|
162 |
+
parsed_output = ast.literal_eval(output_text)
|
163 |
+
|
164 |
+
# Ensure the result is a list
|
165 |
+
if isinstance(parsed_output, dict):
|
166 |
+
parsed_output = [parsed_output]
|
167 |
+
elif not isinstance(parsed_output, list):
|
168 |
+
raise ValueError("Parsed output is neither a dictionary nor a list.")
|
169 |
+
|
170 |
+
# Ensure all elements in the list are dictionaries
|
171 |
+
if not all(isinstance(item, dict) for item in parsed_output):
|
172 |
+
raise ValueError("Not all items in the parsed output are dictionaries.")
|
173 |
+
|
174 |
+
return parsed_output
|
175 |
+
|
176 |
+
except Exception as e:
|
177 |
+
print(f"Error parsing output: {e}")
|
178 |
+
return None
|
computer_use_demo/gui_agent/actor/uitars_agent.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from openai import OpenAI
|
4 |
+
|
5 |
+
from computer_use_demo.gui_agent.llm_utils.oai import encode_image
|
6 |
+
from computer_use_demo.tools.screen_capture import get_screenshot
|
7 |
+
from computer_use_demo.tools.logger import logger, truncate_string
|
8 |
+
|
9 |
+
|
10 |
+
class UITARS_Actor:
|
11 |
+
"""
|
12 |
+
In OOTB, we use the default grounding system prompt form UI_TARS repo, and then convert its action to our action format.
|
13 |
+
"""
|
14 |
+
|
15 |
+
_NAV_SYSTEM_GROUNDING = """
|
16 |
+
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
17 |
+
|
18 |
+
## Output Format
|
19 |
+
```Action: ...```
|
20 |
+
|
21 |
+
## Action Space
|
22 |
+
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
23 |
+
hotkey(key='')
|
24 |
+
type(content='') #If you want to submit your input, use \"\" at the end of `content`.
|
25 |
+
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
26 |
+
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
27 |
+
finished()
|
28 |
+
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
29 |
+
|
30 |
+
## Note
|
31 |
+
- Do not generate any other text.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, ui_tars_url, output_callback, api_key="", selected_screen=0):
|
35 |
+
|
36 |
+
self.ui_tars_url = ui_tars_url
|
37 |
+
self.ui_tars_client = OpenAI(base_url=self.ui_tars_url, api_key=api_key)
|
38 |
+
self.selected_screen = selected_screen
|
39 |
+
self.output_callback = output_callback
|
40 |
+
|
41 |
+
self.grounding_system_prompt = self._NAV_SYSTEM_GROUNDING.format()
|
42 |
+
|
43 |
+
|
44 |
+
def __call__(self, messages):
|
45 |
+
|
46 |
+
task = messages
|
47 |
+
|
48 |
+
# take screenshot
|
49 |
+
screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen, resize=True, target_width=1920, target_height=1080)
|
50 |
+
screenshot_path = str(screenshot_path)
|
51 |
+
screenshot_base64 = encode_image(screenshot_path)
|
52 |
+
|
53 |
+
logger.info(f"Sending messages to UI-TARS on {self.ui_tars_url}: {task}, screenshot: {screenshot_path}")
|
54 |
+
|
55 |
+
response = self.ui_tars_client.chat.completions.create(
|
56 |
+
model="ui-tars",
|
57 |
+
messages=[
|
58 |
+
{"role": "system", "content": self.grounding_system_prompt},
|
59 |
+
{"role": "user", "content": [
|
60 |
+
{"type": "text", "text": task},
|
61 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{screenshot_base64}"}}
|
62 |
+
]
|
63 |
+
},
|
64 |
+
],
|
65 |
+
max_tokens=256,
|
66 |
+
temperature=0
|
67 |
+
)
|
68 |
+
|
69 |
+
ui_tars_action = response.choices[0].message.content
|
70 |
+
converted_action = convert_ui_tars_action_to_json(ui_tars_action)
|
71 |
+
response = str(converted_action)
|
72 |
+
|
73 |
+
response = {'content': response, 'role': 'assistant'}
|
74 |
+
return response
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
def convert_ui_tars_action_to_json(action_str: str) -> str:
|
79 |
+
"""
|
80 |
+
Converts an action line such as:
|
81 |
+
Action: click(start_box='(153,97)')
|
82 |
+
into a JSON string of the form:
|
83 |
+
{
|
84 |
+
"action": "CLICK",
|
85 |
+
"value": null,
|
86 |
+
"position": [153, 97]
|
87 |
+
}
|
88 |
+
"""
|
89 |
+
|
90 |
+
# Strip leading/trailing whitespace and remove "Action: " prefix if present
|
91 |
+
action_str = action_str.strip()
|
92 |
+
if action_str.startswith("Action:"):
|
93 |
+
action_str = action_str[len("Action:"):].strip()
|
94 |
+
|
95 |
+
# Mappings from old action names to the new action schema
|
96 |
+
ACTION_MAP = {
|
97 |
+
"click": "CLICK",
|
98 |
+
"type": "INPUT",
|
99 |
+
"scroll": "SCROLL",
|
100 |
+
"wait": "STOP", # TODO: deal with "wait()"
|
101 |
+
"finished": "STOP",
|
102 |
+
"call_user": "STOP",
|
103 |
+
"hotkey": "HOTKEY", # We break down the actual key below (Enter, Esc, etc.)
|
104 |
+
}
|
105 |
+
|
106 |
+
# Prepare a structure for the final JSON
|
107 |
+
# Default to no position and null value
|
108 |
+
output_dict = {
|
109 |
+
"action": None,
|
110 |
+
"value": None,
|
111 |
+
"position": None
|
112 |
+
}
|
113 |
+
|
114 |
+
# 1) CLICK(...) e.g. click(start_box='(153,97)')
|
115 |
+
match_click = re.match(r"^click\(start_box='\(?(\d+),\s*(\d+)\)?'\)$", action_str)
|
116 |
+
if match_click:
|
117 |
+
x, y = match_click.groups()
|
118 |
+
output_dict["action"] = ACTION_MAP["click"]
|
119 |
+
output_dict["position"] = [int(x), int(y)]
|
120 |
+
return json.dumps(output_dict)
|
121 |
+
|
122 |
+
# 2) HOTKEY(...) e.g. hotkey(key='Enter')
|
123 |
+
match_hotkey = re.match(r"^hotkey\(key='([^']+)'\)$", action_str)
|
124 |
+
if match_hotkey:
|
125 |
+
key = match_hotkey.group(1).lower()
|
126 |
+
if key == "enter":
|
127 |
+
output_dict["action"] = "ENTER"
|
128 |
+
elif key == "esc":
|
129 |
+
output_dict["action"] = "ESC"
|
130 |
+
else:
|
131 |
+
# Otherwise treat it as some generic hotkey
|
132 |
+
output_dict["action"] = ACTION_MAP["hotkey"]
|
133 |
+
output_dict["value"] = key
|
134 |
+
return json.dumps(output_dict)
|
135 |
+
|
136 |
+
# 3) TYPE(...) e.g. type(content='some text')
|
137 |
+
match_type = re.match(r"^type\(content='([^']*)'\)$", action_str)
|
138 |
+
if match_type:
|
139 |
+
typed_content = match_type.group(1)
|
140 |
+
output_dict["action"] = ACTION_MAP["type"]
|
141 |
+
output_dict["value"] = typed_content
|
142 |
+
# If you want a position (x,y) you need it in your string. Otherwise it's omitted.
|
143 |
+
return json.dumps(output_dict)
|
144 |
+
|
145 |
+
# 4) SCROLL(...) e.g. scroll(start_box='(153,97)', direction='down')
|
146 |
+
# or scroll(start_box='...', direction='down')
|
147 |
+
match_scroll = re.match(
|
148 |
+
r"^scroll\(start_box='[^']*'\s*,\s*direction='(down|up|left|right)'\)$",
|
149 |
+
action_str
|
150 |
+
)
|
151 |
+
if match_scroll:
|
152 |
+
direction = match_scroll.group(1)
|
153 |
+
output_dict["action"] = ACTION_MAP["scroll"]
|
154 |
+
output_dict["value"] = direction
|
155 |
+
return json.dumps(output_dict)
|
156 |
+
|
157 |
+
# 5) WAIT() or FINISHED() or CALL_USER() etc.
|
158 |
+
if action_str in ["wait()", "finished()", "call_user()"]:
|
159 |
+
base_action = action_str.replace("()", "")
|
160 |
+
if base_action in ACTION_MAP:
|
161 |
+
output_dict["action"] = ACTION_MAP[base_action]
|
162 |
+
else:
|
163 |
+
output_dict["action"] = "STOP"
|
164 |
+
return json.dumps(output_dict)
|
165 |
+
|
166 |
+
# If none of the above patterns matched, you can decide how to handle
|
167 |
+
# unknown or unexpected action lines:
|
168 |
+
output_dict["action"] = "STOP"
|
169 |
+
return json.dumps(output_dict)
|
computer_use_demo/gui_agent/llm_utils/llm_utils.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import ast
|
4 |
+
import base64
|
5 |
+
|
6 |
+
|
7 |
+
def is_image_path(text):
|
8 |
+
# Checking if the input text ends with typical image file extensions
|
9 |
+
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
|
10 |
+
if text.endswith(image_extensions):
|
11 |
+
return True
|
12 |
+
else:
|
13 |
+
return False
|
14 |
+
|
15 |
+
|
16 |
+
def encode_image(image_path):
|
17 |
+
"""Encode image file to base64."""
|
18 |
+
with open(image_path, "rb") as image_file:
|
19 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
20 |
+
|
21 |
+
|
22 |
+
def is_url_or_filepath(input_string):
|
23 |
+
# Check if input_string is a URL
|
24 |
+
url_pattern = re.compile(
|
25 |
+
r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
|
26 |
+
)
|
27 |
+
if url_pattern.match(input_string):
|
28 |
+
return "URL"
|
29 |
+
|
30 |
+
# Check if input_string is a file path
|
31 |
+
file_path = os.path.abspath(input_string)
|
32 |
+
if os.path.exists(file_path):
|
33 |
+
return "File path"
|
34 |
+
|
35 |
+
return "Invalid"
|
36 |
+
|
37 |
+
|
38 |
+
def extract_data(input_string, data_type):
|
39 |
+
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks
|
40 |
+
pattern = f"```{data_type}" + r"(.*?)(```|$)"
|
41 |
+
# Extract content
|
42 |
+
# re.DOTALL allows '.' to match newlines as well
|
43 |
+
matches = re.findall(pattern, input_string, re.DOTALL)
|
44 |
+
# Return the first match if exists, trimming whitespace and ignoring potential closing backticks
|
45 |
+
return matches[0][0].strip() if matches else input_string
|
46 |
+
|
47 |
+
|
48 |
+
def parse_input(code):
|
49 |
+
"""Use AST to parse the input string and extract the function name, arguments, and keyword arguments."""
|
50 |
+
|
51 |
+
def get_target_names(target):
|
52 |
+
"""Recursively get all variable names from the assignment target."""
|
53 |
+
if isinstance(target, ast.Name):
|
54 |
+
return [target.id]
|
55 |
+
elif isinstance(target, ast.Tuple):
|
56 |
+
names = []
|
57 |
+
for elt in target.elts:
|
58 |
+
names.extend(get_target_names(elt))
|
59 |
+
return names
|
60 |
+
return []
|
61 |
+
|
62 |
+
def extract_value(node):
|
63 |
+
"""提取 AST 节点的实际值"""
|
64 |
+
if isinstance(node, ast.Constant):
|
65 |
+
return node.value
|
66 |
+
elif isinstance(node, ast.Name):
|
67 |
+
# TODO: a better way to handle variables
|
68 |
+
raise ValueError(
|
69 |
+
f"Arguments should be a Constant, got a variable {node.id} instead."
|
70 |
+
)
|
71 |
+
# 添加其他需要处理的 AST 节点类型
|
72 |
+
return None
|
73 |
+
|
74 |
+
try:
|
75 |
+
tree = ast.parse(code)
|
76 |
+
for node in ast.walk(tree):
|
77 |
+
if isinstance(node, ast.Assign):
|
78 |
+
targets = []
|
79 |
+
for t in node.targets:
|
80 |
+
targets.extend(get_target_names(t))
|
81 |
+
if isinstance(node.value, ast.Call):
|
82 |
+
func_name = node.value.func.id
|
83 |
+
args = [ast.dump(arg) for arg in node.value.args]
|
84 |
+
kwargs = {
|
85 |
+
kw.arg: extract_value(kw.value) for kw in node.value.keywords
|
86 |
+
}
|
87 |
+
print(f"Input: {code.strip()}")
|
88 |
+
print(f"Output Variables: {targets}")
|
89 |
+
print(f"Function Name: {func_name}")
|
90 |
+
print(f"Arguments: {args}")
|
91 |
+
print(f"Keyword Arguments: {kwargs}")
|
92 |
+
elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
|
93 |
+
targets = []
|
94 |
+
func_name = extract_value(node.value.func)
|
95 |
+
args = [extract_value(arg) for arg in node.value.args]
|
96 |
+
kwargs = {kw.arg: extract_value(kw.value) for kw in node.value.keywords}
|
97 |
+
|
98 |
+
except SyntaxError:
|
99 |
+
print(f"Input: {code.strip()}")
|
100 |
+
print("No match found")
|
101 |
+
|
102 |
+
return targets, func_name, args, kwargs
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == "__main__":
|
106 |
+
import json
|
107 |
+
s='{"Thinking": "The Docker icon has been successfully clicked, and the Docker application should now be opening. No further actions are required.", "Next Action": None}'
|
108 |
+
json_str = json.loads(s)
|
109 |
+
print(json_str)
|
computer_use_demo/gui_agent/llm_utils/oai.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import base64
|
4 |
+
import requests
|
5 |
+
from computer_use_demo.gui_agent.llm_utils.llm_utils import is_image_path, encode_image
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
|
10 |
+
|
11 |
+
api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
12 |
+
if not api_key:
|
13 |
+
raise ValueError("OPENAI_API_KEY is not set")
|
14 |
+
|
15 |
+
headers = {"Content-Type": "application/json",
|
16 |
+
"Authorization": f"Bearer {api_key}"}
|
17 |
+
|
18 |
+
final_messages = [{"role": "system", "content": system}]
|
19 |
+
|
20 |
+
# image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
21 |
+
if type(messages) == list:
|
22 |
+
for item in messages:
|
23 |
+
contents = []
|
24 |
+
if isinstance(item, dict):
|
25 |
+
for cnt in item["content"]:
|
26 |
+
if isinstance(cnt, str):
|
27 |
+
if is_image_path(cnt):
|
28 |
+
base64_image = encode_image(cnt)
|
29 |
+
content = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
30 |
+
# content = {"type": "image_url", "image_url": {"url": image_url}}
|
31 |
+
else:
|
32 |
+
content = {"type": "text", "text": cnt}
|
33 |
+
contents.append(content)
|
34 |
+
|
35 |
+
message = {"role": item["role"], "content": contents}
|
36 |
+
else: # str
|
37 |
+
contents.append({"type": "text", "text": item})
|
38 |
+
message = {"role": "user", "content": contents}
|
39 |
+
|
40 |
+
final_messages.append(message)
|
41 |
+
|
42 |
+
|
43 |
+
elif isinstance(messages, str):
|
44 |
+
final_messages = [{"role": "user", "content": messages}]
|
45 |
+
|
46 |
+
print("[oai] sending messages:", final_messages)
|
47 |
+
|
48 |
+
payload = {
|
49 |
+
"model": llm,
|
50 |
+
"messages": final_messages,
|
51 |
+
"max_tokens": max_tokens,
|
52 |
+
"temperature": temperature,
|
53 |
+
# "stop": stop,
|
54 |
+
}
|
55 |
+
|
56 |
+
# from IPython.core.debugger import Pdb; Pdb().set_trace()
|
57 |
+
|
58 |
+
response = requests.post(
|
59 |
+
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
|
60 |
+
)
|
61 |
+
|
62 |
+
try:
|
63 |
+
text = response.json()['choices'][0]['message']['content']
|
64 |
+
token_usage = int(response.json()['usage']['total_tokens'])
|
65 |
+
return text, token_usage
|
66 |
+
|
67 |
+
# return error message if the response is not successful
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Error in interleaved openAI: {e}. This may due to your invalid OPENAI_API_KEY. Please check the response: {response.json()} ")
|
70 |
+
return response.json()
|
71 |
+
|
72 |
+
def run_ssh_llm_interleaved(messages: list, system: str, llm: str, ssh_host: str, ssh_port: int, max_tokens=256, temperature=0.7, do_sample=True):
|
73 |
+
"""Send chat completion request to SSH remote server"""
|
74 |
+
from PIL import Image
|
75 |
+
from io import BytesIO
|
76 |
+
def encode_image(image_path: str, max_size=1024) -> str:
|
77 |
+
"""Convert image to base64 encoding with preprocessing"""
|
78 |
+
try:
|
79 |
+
with Image.open(image_path) as img:
|
80 |
+
# Convert to RGB format
|
81 |
+
img = img.convert('RGB')
|
82 |
+
|
83 |
+
# Scale down if image is too large
|
84 |
+
if max(img.size) > max_size:
|
85 |
+
ratio = max_size / max(img.size)
|
86 |
+
new_size = tuple(int(dim * ratio) for dim in img.size)
|
87 |
+
img = img.resize(new_size, Image.LANCZOS)
|
88 |
+
|
89 |
+
# Convert processed image to base64
|
90 |
+
buffered = BytesIO()
|
91 |
+
img.save(buffered, format="JPEG", quality=85)
|
92 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
93 |
+
return img_str
|
94 |
+
except Exception as e:
|
95 |
+
print(f"Image processing failed: {str(e)}")
|
96 |
+
raise
|
97 |
+
|
98 |
+
|
99 |
+
try:
|
100 |
+
# Verify SSH connection info
|
101 |
+
if not ssh_host or not ssh_port:
|
102 |
+
raise ValueError("SSH_HOST and SSH_PORT are not set")
|
103 |
+
|
104 |
+
# Build API URL
|
105 |
+
api_url = f"http://{ssh_host}:{ssh_port}"
|
106 |
+
|
107 |
+
# Prepare message list
|
108 |
+
final_messages = []
|
109 |
+
|
110 |
+
# Add system message
|
111 |
+
if system:
|
112 |
+
final_messages.append({
|
113 |
+
"role": "system",
|
114 |
+
"content": system
|
115 |
+
})
|
116 |
+
|
117 |
+
# Process user messages
|
118 |
+
if type(messages) == list:
|
119 |
+
for item in messages:
|
120 |
+
contents = []
|
121 |
+
if isinstance(item, dict):
|
122 |
+
for cnt in item["content"]:
|
123 |
+
if isinstance(cnt, str):
|
124 |
+
if is_image_path(cnt):
|
125 |
+
base64_image = encode_image(cnt)
|
126 |
+
content = {
|
127 |
+
"type": "image_url",
|
128 |
+
"image_url": {
|
129 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
130 |
+
}
|
131 |
+
}
|
132 |
+
else:
|
133 |
+
content = {
|
134 |
+
"type": "text",
|
135 |
+
"text": cnt
|
136 |
+
}
|
137 |
+
contents.append(content)
|
138 |
+
message = {"role": item["role"], "content": contents}
|
139 |
+
else: # str
|
140 |
+
contents.append({"type": "text", "text": item})
|
141 |
+
message = {"role": "user", "content": contents}
|
142 |
+
final_messages.append(message)
|
143 |
+
elif isinstance(messages, str):
|
144 |
+
final_messages.append({
|
145 |
+
"role": "user",
|
146 |
+
"content": messages
|
147 |
+
})
|
148 |
+
|
149 |
+
# Prepare request data
|
150 |
+
data = {
|
151 |
+
"model": llm,
|
152 |
+
"messages": final_messages,
|
153 |
+
"temperature": temperature,
|
154 |
+
"max_tokens": max_tokens,
|
155 |
+
"do_sample": do_sample
|
156 |
+
}
|
157 |
+
|
158 |
+
print(f"[ssh] Sending chat completion request to model: {llm}")
|
159 |
+
print(f"[ssh] sending messages:", final_messages)
|
160 |
+
|
161 |
+
# Send request
|
162 |
+
response = requests.post(
|
163 |
+
f"{api_url}/v1/chat/completions",
|
164 |
+
json=data,
|
165 |
+
headers={"Content-Type": "application/json"},
|
166 |
+
timeout=30
|
167 |
+
)
|
168 |
+
|
169 |
+
result = response.json()
|
170 |
+
|
171 |
+
if response.status_code == 200:
|
172 |
+
content = result['choices'][0]['message']['content']
|
173 |
+
token_usage = int(result['usage']['total_tokens'])
|
174 |
+
print(f"[ssh] Generation successful: {content}")
|
175 |
+
return content, token_usage
|
176 |
+
else:
|
177 |
+
print(f"[ssh] Request failed: {result}")
|
178 |
+
raise Exception(f"API request failed: {result}")
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
print(f"[ssh] Chat completion request failed: {str(e)}")
|
182 |
+
raise
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
if __name__ == "__main__":
|
187 |
+
|
188 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
189 |
+
if not api_key:
|
190 |
+
raise ValueError("OPENAI_API_KEY is not set")
|
191 |
+
|
192 |
+
# text, token_usage = run_oai_interleaved(
|
193 |
+
# messages= [{"content": [
|
194 |
+
# "What is in the screenshot?",
|
195 |
+
# "./tmp/outputs/screenshot_0b04acbb783d4706bc93873d17ba8c05.png"],
|
196 |
+
# "role": "user"
|
197 |
+
# }],
|
198 |
+
# llm="gpt-4o-mini",
|
199 |
+
# system="You are a helpful assistant",
|
200 |
+
# api_key=api_key,
|
201 |
+
# max_tokens=256,
|
202 |
+
# temperature=0)
|
203 |
+
|
204 |
+
# print(text, token_usage)
|
205 |
+
text, token_usage = run_ssh_llm_interleaved(
|
206 |
+
messages= [{"content": [
|
207 |
+
"What is in the screenshot?",
|
208 |
+
"tmp/outputs/screenshot_5a26d36c59e84272ab58c1b34493d40d.png"],
|
209 |
+
"role": "user"
|
210 |
+
}],
|
211 |
+
llm="Qwen2.5-VL-7B-Instruct",
|
212 |
+
ssh_host="10.245.92.68",
|
213 |
+
ssh_port=9192,
|
214 |
+
max_tokens=256,
|
215 |
+
temperature=0.7
|
216 |
+
)
|
217 |
+
print(text, token_usage)
|
218 |
+
# There is an introduction describing the Calyx... 36986
|
computer_use_demo/gui_agent/llm_utils/qwen.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
import base64
|
5 |
+
import requests
|
6 |
+
|
7 |
+
import dashscope
|
8 |
+
# from computer_use_demo.gui_agent.llm_utils import is_image_path, encode_image
|
9 |
+
|
10 |
+
def is_image_path(text):
|
11 |
+
return False
|
12 |
+
|
13 |
+
def encode_image(image_path):
|
14 |
+
return ""
|
15 |
+
|
16 |
+
|
17 |
+
def run_qwen(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
|
18 |
+
|
19 |
+
api_key = api_key or os.environ.get("QWEN_API_KEY")
|
20 |
+
if not api_key:
|
21 |
+
raise ValueError("QWEN_API_KEY is not set")
|
22 |
+
|
23 |
+
dashscope.api_key = api_key
|
24 |
+
|
25 |
+
# from IPython.core.debugger import Pdb; Pdb().set_trace()
|
26 |
+
|
27 |
+
final_messages = [{"role": "system", "content": [{"text": system}]}]
|
28 |
+
# image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
29 |
+
if type(messages) == list:
|
30 |
+
for item in messages:
|
31 |
+
contents = []
|
32 |
+
if isinstance(item, dict):
|
33 |
+
for cnt in item["content"]:
|
34 |
+
if isinstance(cnt, str):
|
35 |
+
if is_image_path(cnt):
|
36 |
+
# base64_image = encode_image(cnt)
|
37 |
+
content = [{"image": cnt}]
|
38 |
+
# content = {"type": "image_url", "image_url": {"url": image_url}}
|
39 |
+
else:
|
40 |
+
content = {"text": cnt}
|
41 |
+
contents.append(content)
|
42 |
+
|
43 |
+
message = {"role": item["role"], "content": contents}
|
44 |
+
else: # str
|
45 |
+
contents.append({"text": item})
|
46 |
+
message = {"role": "user", "content": contents}
|
47 |
+
|
48 |
+
final_messages.append(message)
|
49 |
+
|
50 |
+
print("[qwen-vl] sending messages:", final_messages)
|
51 |
+
|
52 |
+
response = dashscope.MultiModalConversation.call(
|
53 |
+
model='qwen-vl-max-latest',
|
54 |
+
# model='qwen-vl-max-0809',
|
55 |
+
messages=final_messages
|
56 |
+
)
|
57 |
+
|
58 |
+
# from IPython.core.debugger import Pdb; Pdb().set_trace()
|
59 |
+
|
60 |
+
try:
|
61 |
+
text = response.output.choices[0].message.content[0]['text']
|
62 |
+
usage = response.usage
|
63 |
+
|
64 |
+
if "total_tokens" not in usage:
|
65 |
+
token_usage = int(usage["input_tokens"] + usage["output_tokens"])
|
66 |
+
else:
|
67 |
+
token_usage = int(usage["total_tokens"])
|
68 |
+
|
69 |
+
return text, token_usage
|
70 |
+
# return response.json()['choices'][0]['message']['content']
|
71 |
+
# return error message if the response is not successful
|
72 |
+
except Exception as e:
|
73 |
+
print(f"Error in interleaved openAI: {e}. This may due to your invalid OPENAI_API_KEY. Please check the response: {response.json()} ")
|
74 |
+
return response.json()
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
api_key = os.environ.get("QWEN_API_KEY")
|
80 |
+
if not api_key:
|
81 |
+
raise ValueError("QWEN_API_KEY is not set")
|
82 |
+
|
83 |
+
dashscope.api_key = api_key
|
84 |
+
|
85 |
+
final_messages = [{"role": "user",
|
86 |
+
"content": [
|
87 |
+
{"text": "What is in the screenshot?"},
|
88 |
+
{"image": "./tmp/outputs/screenshot_0b04acbb783d4706bc93873d17ba8c05.png"}
|
89 |
+
]
|
90 |
+
}
|
91 |
+
]
|
92 |
+
response = dashscope.MultiModalConversation.call(model='qwen-vl-max-0809', messages=final_messages)
|
93 |
+
|
94 |
+
print(response)
|
95 |
+
|
96 |
+
text = response.output.choices[0].message.content[0]['text']
|
97 |
+
usage = response.usage
|
98 |
+
|
99 |
+
if "total_tokens" not in usage:
|
100 |
+
if "image_tokens" in usage:
|
101 |
+
token_usage = usage["input_tokens"] + usage["output_tokens"] + usage["image_tokens"]
|
102 |
+
else:
|
103 |
+
token_usage = usage["input_tokens"] + usage["output_tokens"]
|
104 |
+
else:
|
105 |
+
token_usage = usage["total_tokens"]
|
106 |
+
|
107 |
+
print(text, token_usage)
|
108 |
+
# The screenshot is from a video game... 1387
|
computer_use_demo/gui_agent/llm_utils/run_llm.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import logging
|
3 |
+
from .oai import run_oai_interleaved
|
4 |
+
from .gemini import run_gemini_interleaved
|
5 |
+
|
6 |
+
def run_llm(prompt, llm="gpt-4o-mini", max_tokens=256, temperature=0, stop=None):
|
7 |
+
log_prompt(prompt)
|
8 |
+
|
9 |
+
# turn string prompt into list
|
10 |
+
if isinstance(prompt, str):
|
11 |
+
prompt = [prompt]
|
12 |
+
elif isinstance(prompt, list):
|
13 |
+
pass
|
14 |
+
else:
|
15 |
+
raise ValueError(f"Invalid prompt type: {type(prompt)}")
|
16 |
+
|
17 |
+
if llm.startswith("gpt"): # gpt series
|
18 |
+
out = run_oai_interleaved(
|
19 |
+
prompt,
|
20 |
+
llm,
|
21 |
+
max_tokens,
|
22 |
+
temperature,
|
23 |
+
stop
|
24 |
+
)
|
25 |
+
elif llm.startswith("gemini"): # gemini series
|
26 |
+
out = run_gemini_interleaved(
|
27 |
+
prompt,
|
28 |
+
llm,
|
29 |
+
max_tokens,
|
30 |
+
temperature,
|
31 |
+
stop
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
raise ValueError(f"Invalid llm: {llm}")
|
35 |
+
logging.info(
|
36 |
+
f"========Output for {llm}=======\n{out}\n============================")
|
37 |
+
return out
|
38 |
+
|
39 |
+
def log_prompt(prompt):
|
40 |
+
prompt_display = [prompt] if isinstance(prompt, str) else prompt
|
41 |
+
prompt_display = "\n\n".join(prompt_display)
|
42 |
+
logging.info(
|
43 |
+
f"========Prompt=======\n{prompt_display}\n============================")
|
44 |
+
|
computer_use_demo/gui_agent/planner/anthropic_agent.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Agentic sampling loop that calls the Anthropic API and local implementation of anthropic-defined computer use tools.
|
3 |
+
"""
|
4 |
+
import asyncio
|
5 |
+
import platform
|
6 |
+
from collections.abc import Callable
|
7 |
+
from datetime import datetime
|
8 |
+
from enum import StrEnum
|
9 |
+
from typing import Any, cast
|
10 |
+
|
11 |
+
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
|
12 |
+
from anthropic.types import (
|
13 |
+
ToolResultBlockParam,
|
14 |
+
)
|
15 |
+
from anthropic.types.beta import (
|
16 |
+
BetaContentBlock,
|
17 |
+
BetaContentBlockParam,
|
18 |
+
BetaImageBlockParam,
|
19 |
+
BetaMessage,
|
20 |
+
BetaMessageParam,
|
21 |
+
BetaTextBlockParam,
|
22 |
+
BetaToolResultBlockParam,
|
23 |
+
)
|
24 |
+
from anthropic.types import TextBlock
|
25 |
+
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
26 |
+
|
27 |
+
from computer_use_demo.tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
|
28 |
+
|
29 |
+
from PIL import Image
|
30 |
+
from io import BytesIO
|
31 |
+
import gradio as gr
|
32 |
+
from typing import Dict
|
33 |
+
|
34 |
+
|
35 |
+
BETA_FLAG = "computer-use-2024-10-22"
|
36 |
+
|
37 |
+
|
38 |
+
class APIProvider(StrEnum):
|
39 |
+
ANTHROPIC = "anthropic"
|
40 |
+
BEDROCK = "bedrock"
|
41 |
+
VERTEX = "vertex"
|
42 |
+
|
43 |
+
|
44 |
+
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
|
45 |
+
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
|
46 |
+
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
47 |
+
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
# Check OS
|
52 |
+
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
53 |
+
* You are utilizing a Windows system with internet access.
|
54 |
+
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
55 |
+
</SYSTEM_CAPABILITY>
|
56 |
+
"""
|
57 |
+
|
58 |
+
|
59 |
+
class AnthropicActor:
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
model: str,
|
63 |
+
provider: APIProvider,
|
64 |
+
system_prompt_suffix: str,
|
65 |
+
api_key: str,
|
66 |
+
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
|
67 |
+
max_tokens: int = 4096,
|
68 |
+
only_n_most_recent_images: int | None = None,
|
69 |
+
selected_screen: int = 0,
|
70 |
+
print_usage: bool = True,
|
71 |
+
):
|
72 |
+
self.model = model
|
73 |
+
self.provider = provider
|
74 |
+
self.system_prompt_suffix = system_prompt_suffix
|
75 |
+
self.api_key = api_key
|
76 |
+
self.api_response_callback = api_response_callback
|
77 |
+
self.max_tokens = max_tokens
|
78 |
+
self.only_n_most_recent_images = only_n_most_recent_images
|
79 |
+
self.selected_screen = selected_screen
|
80 |
+
|
81 |
+
self.tool_collection = ToolCollection(
|
82 |
+
ComputerTool(selected_screen=selected_screen),
|
83 |
+
BashTool(),
|
84 |
+
EditTool(),
|
85 |
+
)
|
86 |
+
|
87 |
+
self.system = (
|
88 |
+
f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
|
89 |
+
)
|
90 |
+
|
91 |
+
self.total_token_usage = 0
|
92 |
+
self.total_cost = 0
|
93 |
+
self.print_usage = print_usage
|
94 |
+
|
95 |
+
# Instantiate the appropriate API client based on the provider
|
96 |
+
if provider == APIProvider.ANTHROPIC:
|
97 |
+
self.client = Anthropic(api_key=api_key)
|
98 |
+
elif provider == APIProvider.VERTEX:
|
99 |
+
self.client = AnthropicVertex()
|
100 |
+
elif provider == APIProvider.BEDROCK:
|
101 |
+
self.client = AnthropicBedrock()
|
102 |
+
|
103 |
+
def __call__(
|
104 |
+
self,
|
105 |
+
*,
|
106 |
+
messages: list[BetaMessageParam]
|
107 |
+
):
|
108 |
+
"""
|
109 |
+
Generate a response given history messages.
|
110 |
+
"""
|
111 |
+
if self.only_n_most_recent_images:
|
112 |
+
_maybe_filter_to_n_most_recent_images(messages, self.only_n_most_recent_images)
|
113 |
+
|
114 |
+
# Call the API synchronously
|
115 |
+
raw_response = self.client.beta.messages.with_raw_response.create(
|
116 |
+
max_tokens=self.max_tokens,
|
117 |
+
messages=messages,
|
118 |
+
model=self.model,
|
119 |
+
system=self.system,
|
120 |
+
tools=self.tool_collection.to_params(),
|
121 |
+
betas=["computer-use-2024-10-22"],
|
122 |
+
)
|
123 |
+
|
124 |
+
self.api_response_callback(cast(APIResponse[BetaMessage], raw_response))
|
125 |
+
|
126 |
+
response = raw_response.parse()
|
127 |
+
print(f"AnthropicActor response: {response}")
|
128 |
+
|
129 |
+
self.total_token_usage += response.usage.input_tokens + response.usage.output_tokens
|
130 |
+
self.total_cost += (response.usage.input_tokens * 3 / 1000000 + response.usage.output_tokens * 15 / 1000000)
|
131 |
+
|
132 |
+
if self.print_usage:
|
133 |
+
print(f"Claude total token usage so far: {self.total_token_usage}, total cost so far: $USD{self.total_cost}")
|
134 |
+
|
135 |
+
return response
|
136 |
+
|
137 |
+
|
138 |
+
def _maybe_filter_to_n_most_recent_images(
|
139 |
+
messages: list[BetaMessageParam],
|
140 |
+
images_to_keep: int,
|
141 |
+
min_removal_threshold: int = 10,
|
142 |
+
):
|
143 |
+
"""
|
144 |
+
With the assumption that images are screenshots that are of diminishing value as
|
145 |
+
the conversation progresses, remove all but the final `images_to_keep` tool_result
|
146 |
+
images in place, with a chunk of min_removal_threshold to reduce the amount we
|
147 |
+
break the implicit prompt cache.
|
148 |
+
"""
|
149 |
+
if images_to_keep is None:
|
150 |
+
return messages
|
151 |
+
|
152 |
+
tool_result_blocks = cast(
|
153 |
+
list[ToolResultBlockParam],
|
154 |
+
[
|
155 |
+
item
|
156 |
+
for message in messages
|
157 |
+
for item in (
|
158 |
+
message["content"] if isinstance(message["content"], list) else []
|
159 |
+
)
|
160 |
+
if isinstance(item, dict) and item.get("type") == "tool_result"
|
161 |
+
],
|
162 |
+
)
|
163 |
+
|
164 |
+
total_images = sum(
|
165 |
+
1
|
166 |
+
for tool_result in tool_result_blocks
|
167 |
+
for content in tool_result.get("content", [])
|
168 |
+
if isinstance(content, dict) and content.get("type") == "image"
|
169 |
+
)
|
170 |
+
|
171 |
+
images_to_remove = total_images - images_to_keep
|
172 |
+
# for better cache behavior, we want to remove in chunks
|
173 |
+
images_to_remove -= images_to_remove % min_removal_threshold
|
174 |
+
|
175 |
+
for tool_result in tool_result_blocks:
|
176 |
+
if isinstance(tool_result.get("content"), list):
|
177 |
+
new_content = []
|
178 |
+
for content in tool_result.get("content", []):
|
179 |
+
if isinstance(content, dict) and content.get("type") == "image":
|
180 |
+
if images_to_remove > 0:
|
181 |
+
images_to_remove -= 1
|
182 |
+
continue
|
183 |
+
new_content.append(content)
|
184 |
+
tool_result["content"] = new_content
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
if __name__ == "__main__":
|
189 |
+
pass
|
190 |
+
# client = Anthropic(api_key="")
|
191 |
+
# response = client.beta.messages.with_raw_response.create(
|
192 |
+
# max_tokens=4096,
|
193 |
+
# model="claude-3-5-sonnet-20241022",
|
194 |
+
# system=SYSTEM_PROMPT,
|
195 |
+
# # tools=ToolCollection(
|
196 |
+
# # ComputerTool(selected_screen=0),
|
197 |
+
# # BashTool(),
|
198 |
+
# # EditTool(),
|
199 |
+
# # ).to_params(),
|
200 |
+
# betas=["computer-use-2024-10-22"],
|
201 |
+
# messages=[
|
202 |
+
# {"role": "user", "content": "click on (199, 199)."}
|
203 |
+
# ],
|
204 |
+
# )
|
205 |
+
|
206 |
+
# print(f"AnthropicActor response: {response.parse().usage.input_tokens+response.parse().usage.output_tokens}")
|
computer_use_demo/gui_agent/planner/api_vlm_planner.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import asyncio
|
3 |
+
import platform
|
4 |
+
from collections.abc import Callable
|
5 |
+
from datetime import datetime
|
6 |
+
from enum import StrEnum
|
7 |
+
from typing import Any, cast, Dict, Callable
|
8 |
+
|
9 |
+
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
|
10 |
+
from anthropic.types import TextBlock, ToolResultBlockParam
|
11 |
+
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam
|
12 |
+
|
13 |
+
from computer_use_demo.tools.screen_capture import get_screenshot
|
14 |
+
from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved, run_ssh_llm_interleaved
|
15 |
+
from computer_use_demo.gui_agent.llm_utils.qwen import run_qwen
|
16 |
+
from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data, encode_image
|
17 |
+
from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
21 |
+
from qwen_vl_utils import process_vision_info
|
22 |
+
|
23 |
+
|
24 |
+
class APIVLMPlanner:
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
model: str,
|
28 |
+
provider: str,
|
29 |
+
system_prompt_suffix: str,
|
30 |
+
api_key: str,
|
31 |
+
output_callback: Callable,
|
32 |
+
api_response_callback: Callable,
|
33 |
+
max_tokens: int = 4096,
|
34 |
+
only_n_most_recent_images: int | None = None,
|
35 |
+
selected_screen: int = 0,
|
36 |
+
print_usage: bool = True,
|
37 |
+
device: torch.device = torch.device("cpu"),
|
38 |
+
):
|
39 |
+
self.device = device
|
40 |
+
if model == "gpt-4o":
|
41 |
+
self.model = "gpt-4o-2024-11-20"
|
42 |
+
elif model == "gpt-4o-mini":
|
43 |
+
self.model = "gpt-4o-mini" # "gpt-4o-mini"
|
44 |
+
elif model == "qwen2-vl-max":
|
45 |
+
self.model = "qwen2-vl-max"
|
46 |
+
elif model == "qwen2-vl-2b (ssh)":
|
47 |
+
self.model = "Qwen2-VL-2B-Instruct"
|
48 |
+
elif model == "qwen2-vl-7b (ssh)":
|
49 |
+
self.model = "Qwen2-VL-7B-Instruct"
|
50 |
+
elif model == "qwen2.5-vl-7b (ssh)":
|
51 |
+
self.model = "Qwen2.5-VL-7B-Instruct"
|
52 |
+
elif model == "qwen-vl-7b-instruct": # local model
|
53 |
+
self.model = "qwen-vl-7b-instruct"
|
54 |
+
self.min_pixels = 256 * 28 * 28
|
55 |
+
self.max_pixels = 1344 * 28 * 28
|
56 |
+
self.processor = AutoProcessor.from_pretrained(
|
57 |
+
"./Qwen2-VL-7B-Instruct",
|
58 |
+
min_pixels=self.min_pixels,
|
59 |
+
max_pixels=self.max_pixels
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Model {model} not supported")
|
63 |
+
|
64 |
+
self.provider = provider
|
65 |
+
self.system_prompt_suffix = system_prompt_suffix
|
66 |
+
self.api_key = api_key
|
67 |
+
self.api_response_callback = api_response_callback
|
68 |
+
self.max_tokens = max_tokens
|
69 |
+
self.only_n_most_recent_images = only_n_most_recent_images
|
70 |
+
self.selected_screen = selected_screen
|
71 |
+
self.output_callback = output_callback
|
72 |
+
self.system_prompt = self._get_system_prompt() + self.system_prompt_suffix
|
73 |
+
|
74 |
+
|
75 |
+
self.print_usage = print_usage
|
76 |
+
self.total_token_usage = 0
|
77 |
+
self.total_cost = 0
|
78 |
+
|
79 |
+
|
80 |
+
def __call__(self, messages: list):
|
81 |
+
|
82 |
+
# drop looping actions msg, byte image etc
|
83 |
+
planner_messages = _message_filter_callback(messages)
|
84 |
+
print(f"filtered_messages: {planner_messages}")
|
85 |
+
|
86 |
+
if self.only_n_most_recent_images:
|
87 |
+
_maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
|
88 |
+
|
89 |
+
# Take a screenshot
|
90 |
+
screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen)
|
91 |
+
screenshot_path = str(screenshot_path)
|
92 |
+
image_base64 = encode_image(screenshot_path)
|
93 |
+
self.output_callback(f'Screenshot for {colorful_text_vlm}:\n<img src="data:image/png;base64,{image_base64}">',
|
94 |
+
sender="bot")
|
95 |
+
|
96 |
+
if isinstance(planner_messages[-1], dict):
|
97 |
+
if not isinstance(planner_messages[-1]["content"], list):
|
98 |
+
planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
|
99 |
+
planner_messages[-1]["content"].append(screenshot_path)
|
100 |
+
|
101 |
+
print(f"Sending messages to VLMPlanner: {planner_messages}")
|
102 |
+
|
103 |
+
if self.model == "gpt-4o-2024-11-20":
|
104 |
+
vlm_response, token_usage = run_oai_interleaved(
|
105 |
+
messages=planner_messages,
|
106 |
+
system=self.system_prompt,
|
107 |
+
llm=self.model,
|
108 |
+
api_key=self.api_key,
|
109 |
+
max_tokens=self.max_tokens,
|
110 |
+
temperature=0,
|
111 |
+
)
|
112 |
+
print(f"oai token usage: {token_usage}")
|
113 |
+
self.total_token_usage += token_usage
|
114 |
+
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
|
115 |
+
|
116 |
+
elif self.model == "qwen2-vl-max":
|
117 |
+
vlm_response, token_usage = run_qwen(
|
118 |
+
messages=planner_messages,
|
119 |
+
system=self.system_prompt,
|
120 |
+
llm=self.model,
|
121 |
+
api_key=self.api_key,
|
122 |
+
max_tokens=self.max_tokens,
|
123 |
+
temperature=0,
|
124 |
+
)
|
125 |
+
print(f"qwen token usage: {token_usage}")
|
126 |
+
self.total_token_usage += token_usage
|
127 |
+
self.total_cost += (token_usage * 0.02 / 7.25 / 1000) # 1USD=7.25CNY, https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api
|
128 |
+
elif "Qwen" in self.model:
|
129 |
+
# 从api_key中解析host和port
|
130 |
+
try:
|
131 |
+
ssh_host, ssh_port = self.api_key.split(":")
|
132 |
+
ssh_port = int(ssh_port)
|
133 |
+
except ValueError:
|
134 |
+
raise ValueError("Invalid SSH connection string. Expected format: host:port")
|
135 |
+
|
136 |
+
vlm_response, token_usage = run_ssh_llm_interleaved(
|
137 |
+
messages=planner_messages,
|
138 |
+
system=self.system_prompt,
|
139 |
+
llm=self.model,
|
140 |
+
ssh_host=ssh_host,
|
141 |
+
ssh_port=ssh_port,
|
142 |
+
max_tokens=self.max_tokens,
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
raise ValueError(f"Model {self.model} not supported")
|
146 |
+
|
147 |
+
print(f"VLMPlanner response: {vlm_response}")
|
148 |
+
|
149 |
+
if self.print_usage:
|
150 |
+
print(f"VLMPlanner total token usage so far: {self.total_token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
|
151 |
+
|
152 |
+
vlm_response_json = extract_data(vlm_response, "json")
|
153 |
+
|
154 |
+
# vlm_plan_str = '\n'.join([f'{key}: {value}' for key, value in json.loads(response).items()])
|
155 |
+
vlm_plan_str = ""
|
156 |
+
for key, value in json.loads(vlm_response_json).items():
|
157 |
+
if key == "Thinking":
|
158 |
+
vlm_plan_str += f'{value}'
|
159 |
+
else:
|
160 |
+
vlm_plan_str += f'\n{key}: {value}'
|
161 |
+
|
162 |
+
self.output_callback(f"{colorful_text_vlm}:\n{vlm_plan_str}", sender="bot")
|
163 |
+
|
164 |
+
return vlm_response_json
|
165 |
+
|
166 |
+
|
167 |
+
def _api_response_callback(self, response: APIResponse):
|
168 |
+
self.api_response_callback(response)
|
169 |
+
|
170 |
+
|
171 |
+
def reformat_messages(self, messages: list):
|
172 |
+
pass
|
173 |
+
|
174 |
+
def _get_system_prompt(self):
|
175 |
+
os_name = platform.system()
|
176 |
+
return f"""
|
177 |
+
You are using an {os_name} device.
|
178 |
+
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
179 |
+
You can only interact with the desktop GUI (no terminal or application menu access).
|
180 |
+
|
181 |
+
You may be given some history plan and actions, this is the response from the previous loop.
|
182 |
+
You should carefully consider your plan base on the task, screenshot, and history actions.
|
183 |
+
|
184 |
+
Your available "Next Action" only include:
|
185 |
+
- ENTER: Press an enter key.
|
186 |
+
- ESCAPE: Press an ESCAPE key.
|
187 |
+
- INPUT: Input a string of text.
|
188 |
+
- CLICK: Describe the ui element to be clicked.
|
189 |
+
- HOVER: Describe the ui element to be hovered.
|
190 |
+
- SCROLL: Scroll the screen, you must specify up or down.
|
191 |
+
- PRESS: Describe the ui element to be pressed.
|
192 |
+
|
193 |
+
|
194 |
+
Output format:
|
195 |
+
```json
|
196 |
+
{{
|
197 |
+
"Thinking": str, # describe your thoughts on how to achieve the task, choose one action from available actions at a time.
|
198 |
+
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
199 |
+
}}
|
200 |
+
```
|
201 |
+
|
202 |
+
One Example:
|
203 |
+
```json
|
204 |
+
{{
|
205 |
+
"Thinking": "I need to search and navigate to amazon.com.",
|
206 |
+
"Next Action": "CLICK 'Search Google or type a URL'."
|
207 |
+
}}
|
208 |
+
```
|
209 |
+
|
210 |
+
IMPORTANT NOTES:
|
211 |
+
1. Carefully observe the screenshot to understand the current state and read history actions.
|
212 |
+
2. You should only give a single action at a time. for example, INPUT text, and ENTER can't be in one Next Action.
|
213 |
+
3. Attach the text to Next Action, if there is text or any description for the button.
|
214 |
+
4. You should not include other actions, such as keyboard shortcuts.
|
215 |
+
5. When the task is completed, you should say "Next Action": "None" in the json field.
|
216 |
+
"""
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
def _maybe_filter_to_n_most_recent_images(
|
221 |
+
messages: list[BetaMessageParam],
|
222 |
+
images_to_keep: int,
|
223 |
+
min_removal_threshold: int = 10,
|
224 |
+
):
|
225 |
+
"""
|
226 |
+
With the assumption that images are screenshots that are of diminishing value as
|
227 |
+
the conversation progresses, remove all but the final `images_to_keep` tool_result
|
228 |
+
images in place, with a chunk of min_removal_threshold to reduce the amount we
|
229 |
+
break the implicit prompt cache.
|
230 |
+
"""
|
231 |
+
if images_to_keep is None:
|
232 |
+
return messages
|
233 |
+
|
234 |
+
tool_result_blocks = cast(
|
235 |
+
list[ToolResultBlockParam],
|
236 |
+
[
|
237 |
+
item
|
238 |
+
for message in messages
|
239 |
+
for item in (
|
240 |
+
message["content"] if isinstance(message["content"], list) else []
|
241 |
+
)
|
242 |
+
if isinstance(item, dict) and item.get("type") == "tool_result"
|
243 |
+
],
|
244 |
+
)
|
245 |
+
|
246 |
+
total_images = sum(
|
247 |
+
1
|
248 |
+
for tool_result in tool_result_blocks
|
249 |
+
for content in tool_result.get("content", [])
|
250 |
+
if isinstance(content, dict) and content.get("type") == "image"
|
251 |
+
)
|
252 |
+
|
253 |
+
images_to_remove = total_images - images_to_keep
|
254 |
+
# for better cache behavior, we want to remove in chunks
|
255 |
+
images_to_remove -= images_to_remove % min_removal_threshold
|
256 |
+
|
257 |
+
for tool_result in tool_result_blocks:
|
258 |
+
if isinstance(tool_result.get("content"), list):
|
259 |
+
new_content = []
|
260 |
+
for content in tool_result.get("content", []):
|
261 |
+
if isinstance(content, dict) and content.get("type") == "image":
|
262 |
+
if images_to_remove > 0:
|
263 |
+
images_to_remove -= 1
|
264 |
+
continue
|
265 |
+
new_content.append(content)
|
266 |
+
tool_result["content"] = new_content
|
267 |
+
|
268 |
+
|
269 |
+
def _message_filter_callback(messages):
|
270 |
+
filtered_list = []
|
271 |
+
try:
|
272 |
+
for msg in messages:
|
273 |
+
if msg.get('role') in ['user']:
|
274 |
+
if not isinstance(msg["content"], list):
|
275 |
+
msg["content"] = [msg["content"]]
|
276 |
+
if isinstance(msg["content"][0], TextBlock):
|
277 |
+
filtered_list.append(str(msg["content"][0].text)) # User message
|
278 |
+
elif isinstance(msg["content"][0], str):
|
279 |
+
filtered_list.append(msg["content"][0]) # User message
|
280 |
+
else:
|
281 |
+
print("[_message_filter_callback]: drop message", msg)
|
282 |
+
continue
|
283 |
+
|
284 |
+
# elif msg.get('role') in ['assistant']:
|
285 |
+
# if isinstance(msg["content"][0], TextBlock):
|
286 |
+
# msg["content"][0] = str(msg["content"][0].text)
|
287 |
+
# elif isinstance(msg["content"][0], BetaTextBlock):
|
288 |
+
# msg["content"][0] = str(msg["content"][0].text)
|
289 |
+
# elif isinstance(msg["content"][0], BetaToolUseBlock):
|
290 |
+
# msg["content"][0] = str(msg['content'][0].input)
|
291 |
+
# elif isinstance(msg["content"][0], Dict) and msg["content"][0]["content"][-1]["type"] == "image":
|
292 |
+
# msg["content"][0] = f'<img src="data:image/png;base64,{msg["content"][0]["content"][-1]["source"]["data"]}">'
|
293 |
+
# else:
|
294 |
+
# print("[_message_filter_callback]: drop message", msg)
|
295 |
+
# continue
|
296 |
+
# filtered_list.append(msg["content"][0]) # User message
|
297 |
+
|
298 |
+
else:
|
299 |
+
print("[_message_filter_callback]: drop message", msg)
|
300 |
+
continue
|
301 |
+
|
302 |
+
except Exception as e:
|
303 |
+
print("[_message_filter_callback]: error", e)
|
304 |
+
|
305 |
+
return filtered_list
|
computer_use_demo/gui_agent/planner/local_vlm_planner.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import asyncio
|
3 |
+
import platform
|
4 |
+
from collections.abc import Callable
|
5 |
+
from datetime import datetime
|
6 |
+
from enum import StrEnum
|
7 |
+
from typing import Any, cast, Dict, Callable
|
8 |
+
|
9 |
+
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
|
10 |
+
from anthropic.types import TextBlock, ToolResultBlockParam
|
11 |
+
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam
|
12 |
+
|
13 |
+
from computer_use_demo.tools.screen_capture import get_screenshot
|
14 |
+
from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data, encode_image
|
15 |
+
from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
19 |
+
from qwen_vl_utils import process_vision_info
|
20 |
+
|
21 |
+
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
22 |
+
* You are utilizing a Windows system with internet access.
|
23 |
+
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
24 |
+
</SYSTEM_CAPABILITY>
|
25 |
+
"""
|
26 |
+
|
27 |
+
class LocalVLMPlanner:
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
model: str,
|
31 |
+
provider: str,
|
32 |
+
system_prompt_suffix: str,
|
33 |
+
output_callback: Callable,
|
34 |
+
api_response_callback: Callable,
|
35 |
+
max_tokens: int = 4096,
|
36 |
+
only_n_most_recent_images: int | None = None,
|
37 |
+
selected_screen: int = 0,
|
38 |
+
print_usage: bool = True,
|
39 |
+
device: torch.device = torch.device("cpu"),
|
40 |
+
):
|
41 |
+
self.device = device
|
42 |
+
self.min_pixels = 256 * 28 * 28
|
43 |
+
self.max_pixels = 1344 * 28 * 28
|
44 |
+
|
45 |
+
if model == "qwen-vl-7b-instruct": # local model
|
46 |
+
self.model_name = "qwen-vl-7b-instruct"
|
47 |
+
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
48 |
+
# "Qwen/Qwen2-VL-7B-Instruct",
|
49 |
+
"./Qwen2-VL-7B-Instruct",
|
50 |
+
torch_dtype=torch.bfloat16,
|
51 |
+
device_map="cpu"
|
52 |
+
).to(self.device)
|
53 |
+
self.processor = AutoProcessor.from_pretrained(
|
54 |
+
"./Qwen2-VL-7B-Instruct",
|
55 |
+
min_pixels=self.min_pixels,
|
56 |
+
max_pixels=self.max_pixels
|
57 |
+
)
|
58 |
+
|
59 |
+
elif model == "qwen2-vl-2b-instruct":
|
60 |
+
self.model_name = "qwen2-vl-2b-instruct"
|
61 |
+
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
62 |
+
# "Qwen/Qwen2-VL-2B-Instruct",
|
63 |
+
"./Qwen2-VL-2B-Instruct",
|
64 |
+
torch_dtype=torch.bfloat16,
|
65 |
+
device_map="cpu"
|
66 |
+
).to(self.device)
|
67 |
+
self.processor = AutoProcessor.from_pretrained(
|
68 |
+
"./Qwen2-VL-2B-Instruct",
|
69 |
+
min_pixels=self.min_pixels,
|
70 |
+
max_pixels=self.max_pixels
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
raise ValueError(f"Model {model} not supported")
|
74 |
+
|
75 |
+
self.provider = provider
|
76 |
+
self.system_prompt_suffix = system_prompt_suffix
|
77 |
+
self.api_response_callback = api_response_callback
|
78 |
+
self.max_tokens = max_tokens
|
79 |
+
self.only_n_most_recent_images = only_n_most_recent_images
|
80 |
+
self.selected_screen = selected_screen
|
81 |
+
self.output_callback = output_callback
|
82 |
+
self.system_prompt = self._get_system_prompt() + self.system_prompt_suffix
|
83 |
+
|
84 |
+
self.print_usage = print_usage
|
85 |
+
self.total_token_usage = 0
|
86 |
+
self.total_cost = 0
|
87 |
+
|
88 |
+
|
89 |
+
def __call__(self, messages: list):
|
90 |
+
|
91 |
+
# drop looping actions msg, byte image etc
|
92 |
+
planner_messages = _message_filter_callback(messages)
|
93 |
+
print(f"filtered_messages: {planner_messages}")
|
94 |
+
|
95 |
+
# Take a screenshot
|
96 |
+
screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen)
|
97 |
+
screenshot_path = str(screenshot_path)
|
98 |
+
image_base64 = encode_image(screenshot_path)
|
99 |
+
self.output_callback(f'Screenshot for {colorful_text_vlm}:\n<img src="data:image/png;base64,{image_base64}">',
|
100 |
+
sender="bot")
|
101 |
+
|
102 |
+
if isinstance(planner_messages[-1], dict):
|
103 |
+
if not isinstance(planner_messages[-1]["content"], list):
|
104 |
+
planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
|
105 |
+
planner_messages[-1]["content"].append(screenshot_path)
|
106 |
+
|
107 |
+
print(f"Sending messages to VLMPlanner: {planner_messages}")
|
108 |
+
|
109 |
+
messages_for_processor = [
|
110 |
+
{
|
111 |
+
"role": "system",
|
112 |
+
"content": [{"type": "text", "text": self.system_prompt}]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"role": "user",
|
116 |
+
"content": [
|
117 |
+
{"type": "image", "image": screenshot_path, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels},
|
118 |
+
{"type": "text", "text": f"Task: {''.join(planner_messages)}"}
|
119 |
+
],
|
120 |
+
}]
|
121 |
+
|
122 |
+
text = self.processor.apply_chat_template(
|
123 |
+
messages_for_processor, tokenize=False, add_generation_prompt=True
|
124 |
+
)
|
125 |
+
image_inputs, video_inputs = process_vision_info(messages_for_processor)
|
126 |
+
|
127 |
+
inputs = self.processor(
|
128 |
+
text=[text],
|
129 |
+
images=image_inputs,
|
130 |
+
videos=video_inputs,
|
131 |
+
padding=True,
|
132 |
+
return_tensors="pt",
|
133 |
+
)
|
134 |
+
inputs = inputs.to(self.device)
|
135 |
+
|
136 |
+
generated_ids = self.model.generate(**inputs, max_new_tokens=128)
|
137 |
+
generated_ids_trimmed = [
|
138 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
139 |
+
]
|
140 |
+
vlm_response = self.processor.batch_decode(
|
141 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
142 |
+
)[0]
|
143 |
+
|
144 |
+
print(f"VLMPlanner response: {vlm_response}")
|
145 |
+
|
146 |
+
vlm_response_json = extract_data(vlm_response, "json")
|
147 |
+
|
148 |
+
# vlm_plan_str = '\n'.join([f'{key}: {value}' for key, value in json.loads(response).items()])
|
149 |
+
vlm_plan_str = ""
|
150 |
+
for key, value in json.loads(vlm_response_json).items():
|
151 |
+
if key == "Thinking":
|
152 |
+
vlm_plan_str += f'{value}'
|
153 |
+
else:
|
154 |
+
vlm_plan_str += f'\n{key}: {value}'
|
155 |
+
|
156 |
+
self.output_callback(f"{colorful_text_vlm}:\n{vlm_plan_str}", sender="bot")
|
157 |
+
|
158 |
+
return vlm_response_json
|
159 |
+
|
160 |
+
|
161 |
+
def _api_response_callback(self, response: APIResponse):
|
162 |
+
self.api_response_callback(response)
|
163 |
+
|
164 |
+
|
165 |
+
def reformat_messages(self, messages: list):
|
166 |
+
pass
|
167 |
+
|
168 |
+
def _get_system_prompt(self):
|
169 |
+
os_name = platform.system()
|
170 |
+
return f"""
|
171 |
+
You are using an {os_name} device.
|
172 |
+
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
173 |
+
You can only interact with the desktop GUI (no terminal or application menu access).
|
174 |
+
|
175 |
+
You may be given some history plan and actions, this is the response from the previous loop.
|
176 |
+
You should carefully consider your plan base on the task, screenshot, and history actions.
|
177 |
+
|
178 |
+
Your available "Next Action" only include:
|
179 |
+
- ENTER: Press an enter key.
|
180 |
+
- ESCAPE: Press an ESCAPE key.
|
181 |
+
- INPUT: Input a string of text.
|
182 |
+
- CLICK: Describe the ui element to be clicked.
|
183 |
+
- HOVER: Describe the ui element to be hovered.
|
184 |
+
- SCROLL: Scroll the screen, you must specify up or down.
|
185 |
+
- PRESS: Describe the ui element to be pressed.
|
186 |
+
|
187 |
+
Output format:
|
188 |
+
```json
|
189 |
+
{{
|
190 |
+
"Thinking": str, # describe your thoughts on how to achieve the task, choose one action from available actions at a time.
|
191 |
+
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
192 |
+
}}
|
193 |
+
```
|
194 |
+
|
195 |
+
One Example:
|
196 |
+
```json
|
197 |
+
{{
|
198 |
+
"Thinking": "I need to search and navigate to amazon.com.",
|
199 |
+
"Next Action": "CLICK 'Search Google or type a URL'."
|
200 |
+
}}
|
201 |
+
```
|
202 |
+
|
203 |
+
IMPORTANT NOTES:
|
204 |
+
1. Carefully observe the screenshot to understand the current state and read history actions.
|
205 |
+
2. You should only give a single action at a time. for example, INPUT text, and ENTER can't be in one Next Action.
|
206 |
+
3. Attach the text to Next Action, if there is text or any description for the button.
|
207 |
+
4. You should not include other actions, such as keyboard shortcuts.
|
208 |
+
5. When the task is completed, you should say "Next Action": "None" in the json field.
|
209 |
+
"""
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
def _message_filter_callback(messages):
|
214 |
+
filtered_list = []
|
215 |
+
try:
|
216 |
+
for msg in messages:
|
217 |
+
if msg.get('role') in ['user']:
|
218 |
+
if not isinstance(msg["content"], list):
|
219 |
+
msg["content"] = [msg["content"]]
|
220 |
+
if isinstance(msg["content"][0], TextBlock):
|
221 |
+
filtered_list.append(str(msg["content"][0].text)) # User message
|
222 |
+
elif isinstance(msg["content"][0], str):
|
223 |
+
filtered_list.append(msg["content"][0]) # User message
|
224 |
+
else:
|
225 |
+
print("[_message_filter_callback]: drop message", msg)
|
226 |
+
continue
|
227 |
+
|
228 |
+
else:
|
229 |
+
print("[_message_filter_callback]: drop message", msg)
|
230 |
+
continue
|
231 |
+
|
232 |
+
except Exception as e:
|
233 |
+
print("[_message_filter_callback]: error", e)
|
234 |
+
|
235 |
+
return filtered_list
|
computer_use_demo/loop.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Agentic sampling loop that calls the Anthropic API and local implementation of computer use tools.
|
3 |
+
"""
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
from collections.abc import Callable
|
7 |
+
from enum import StrEnum
|
8 |
+
|
9 |
+
from anthropic import APIResponse
|
10 |
+
from anthropic.types.beta import BetaContentBlock, BetaMessage, BetaMessageParam
|
11 |
+
from computer_use_demo.tools import ToolResult
|
12 |
+
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from computer_use_demo.gui_agent.planner.anthropic_agent import AnthropicActor
|
17 |
+
from computer_use_demo.executor.anthropic_executor import AnthropicExecutor
|
18 |
+
from computer_use_demo.gui_agent.planner.api_vlm_planner import APIVLMPlanner
|
19 |
+
from computer_use_demo.gui_agent.planner.local_vlm_planner import LocalVLMPlanner
|
20 |
+
from computer_use_demo.gui_agent.actor.showui_agent import ShowUIActor
|
21 |
+
from computer_use_demo.executor.showui_executor import ShowUIExecutor
|
22 |
+
from computer_use_demo.gui_agent.actor.uitars_agent import UITARS_Actor
|
23 |
+
from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
|
24 |
+
from computer_use_demo.tools.screen_capture import get_screenshot
|
25 |
+
from computer_use_demo.gui_agent.llm_utils.oai import encode_image
|
26 |
+
|
27 |
+
from computer_use_demo.tools.logger import logger
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
class APIProvider(StrEnum):
|
32 |
+
ANTHROPIC = "anthropic"
|
33 |
+
BEDROCK = "bedrock"
|
34 |
+
VERTEX = "vertex"
|
35 |
+
OPENAI = "openai"
|
36 |
+
QWEN = "qwen"
|
37 |
+
SSH = "ssh"
|
38 |
+
|
39 |
+
|
40 |
+
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
|
41 |
+
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
|
42 |
+
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
43 |
+
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
|
44 |
+
APIProvider.OPENAI: "gpt-4o",
|
45 |
+
APIProvider.QWEN: "qwen2vl",
|
46 |
+
APIProvider.SSH: "qwen2-vl-2b",
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
def sampling_loop_sync(
|
51 |
+
*,
|
52 |
+
planner_model: str,
|
53 |
+
planner_provider: APIProvider | None,
|
54 |
+
actor_model: str,
|
55 |
+
actor_provider: APIProvider | None,
|
56 |
+
system_prompt_suffix: str,
|
57 |
+
messages: list[BetaMessageParam],
|
58 |
+
output_callback: Callable[[BetaContentBlock], None],
|
59 |
+
tool_output_callback: Callable[[ToolResult, str], None],
|
60 |
+
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
|
61 |
+
api_key: str,
|
62 |
+
only_n_most_recent_images: int | None = None,
|
63 |
+
max_tokens: int = 4096,
|
64 |
+
selected_screen: int = 0,
|
65 |
+
showui_max_pixels: int = 1344,
|
66 |
+
showui_awq_4bit: bool = False,
|
67 |
+
ui_tars_url: str = ""
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
|
71 |
+
"""
|
72 |
+
|
73 |
+
# ---------------------------
|
74 |
+
# Initialize Planner
|
75 |
+
# ---------------------------
|
76 |
+
if planner_model == "claude-3-5-sonnet-20241022":
|
77 |
+
# Register Actor and Executor
|
78 |
+
actor = AnthropicActor(
|
79 |
+
model=planner_model,
|
80 |
+
provider=actor_provider,
|
81 |
+
system_prompt_suffix=system_prompt_suffix,
|
82 |
+
api_key=api_key,
|
83 |
+
api_response_callback=api_response_callback,
|
84 |
+
max_tokens=max_tokens,
|
85 |
+
only_n_most_recent_images=only_n_most_recent_images,
|
86 |
+
selected_screen=selected_screen
|
87 |
+
)
|
88 |
+
|
89 |
+
executor = AnthropicExecutor(
|
90 |
+
output_callback=output_callback,
|
91 |
+
tool_output_callback=tool_output_callback,
|
92 |
+
selected_screen=selected_screen
|
93 |
+
)
|
94 |
+
|
95 |
+
loop_mode = "unified"
|
96 |
+
|
97 |
+
elif planner_model in ["gpt-4o", "gpt-4o-mini", "qwen2-vl-max"]:
|
98 |
+
|
99 |
+
if torch.cuda.is_available(): device = torch.device("cuda")
|
100 |
+
elif torch.backends.mps.is_available(): device = torch.device("mps")
|
101 |
+
else: device = torch.device("cpu") # support: 'cpu', 'mps', 'cuda'
|
102 |
+
logger.info(f"Model inited on device: {device}.")
|
103 |
+
|
104 |
+
planner = APIVLMPlanner(
|
105 |
+
model=planner_model,
|
106 |
+
provider=planner_provider,
|
107 |
+
system_prompt_suffix=system_prompt_suffix,
|
108 |
+
api_key=api_key,
|
109 |
+
api_response_callback=api_response_callback,
|
110 |
+
selected_screen=selected_screen,
|
111 |
+
output_callback=output_callback,
|
112 |
+
device=device
|
113 |
+
)
|
114 |
+
loop_mode = "planner + actor"
|
115 |
+
|
116 |
+
elif planner_model == "qwen2-vl-7b-instruct":
|
117 |
+
planner = LocalVLMPlanner(
|
118 |
+
model=planner_model,
|
119 |
+
provider=planner_provider,
|
120 |
+
system_prompt_suffix=system_prompt_suffix,
|
121 |
+
api_key=api_key,
|
122 |
+
api_response_callback=api_response_callback,
|
123 |
+
selected_screen=selected_screen,
|
124 |
+
output_callback=output_callback,
|
125 |
+
device=device
|
126 |
+
)
|
127 |
+
loop_mode = "planner + actor"
|
128 |
+
elif "ssh" in planner_model:
|
129 |
+
if torch.cuda.is_available(): device = torch.device("cuda")
|
130 |
+
elif torch.backends.mps.is_available(): device = torch.device("mps")
|
131 |
+
else: device = torch.device("cpu") # support: 'cpu', 'mps', 'cuda'
|
132 |
+
logger.info(f"Model inited on device: {device}.")
|
133 |
+
planner = APIVLMPlanner(
|
134 |
+
model=planner_model,
|
135 |
+
provider=planner_provider,
|
136 |
+
system_prompt_suffix=system_prompt_suffix,
|
137 |
+
api_key=api_key,
|
138 |
+
api_response_callback=api_response_callback,
|
139 |
+
selected_screen=selected_screen,
|
140 |
+
output_callback=output_callback,
|
141 |
+
device=device
|
142 |
+
)
|
143 |
+
loop_mode = "planner + actor"
|
144 |
+
else:
|
145 |
+
logger.error(f"Planner Model {planner_model} not supported")
|
146 |
+
raise ValueError(f"Planner Model {planner_model} not supported")
|
147 |
+
|
148 |
+
|
149 |
+
# ---------------------------
|
150 |
+
# Initialize Actor
|
151 |
+
# ---------------------------
|
152 |
+
if actor_model == "ShowUI":
|
153 |
+
if showui_awq_4bit:
|
154 |
+
showui_model_path = "./showui-2b-awq-4bit/"
|
155 |
+
else:
|
156 |
+
showui_model_path = "./showui-2b/"
|
157 |
+
|
158 |
+
actor = ShowUIActor(
|
159 |
+
model_path=showui_model_path,
|
160 |
+
device=device,
|
161 |
+
split='desktop', # 'desktop' or 'phone'
|
162 |
+
selected_screen=selected_screen,
|
163 |
+
output_callback=output_callback,
|
164 |
+
max_pixels=showui_max_pixels,
|
165 |
+
awq_4bit=showui_awq_4bit
|
166 |
+
)
|
167 |
+
|
168 |
+
executor = ShowUIExecutor(
|
169 |
+
output_callback=output_callback,
|
170 |
+
tool_output_callback=tool_output_callback,
|
171 |
+
selected_screen=selected_screen
|
172 |
+
)
|
173 |
+
elif actor_model == "UI-TARS":
|
174 |
+
actor = UITARS_Actor(
|
175 |
+
ui_tars_url=ui_tars_url,
|
176 |
+
output_callback=output_callback,
|
177 |
+
selected_screen=selected_screen
|
178 |
+
)
|
179 |
+
|
180 |
+
else:
|
181 |
+
raise ValueError(f"Actor Model {actor_model} not supported")
|
182 |
+
|
183 |
+
|
184 |
+
tool_result_content = None
|
185 |
+
showui_loop_count = 0
|
186 |
+
|
187 |
+
logger.info(f"Start the message loop. User messages: {messages}")
|
188 |
+
|
189 |
+
if loop_mode == "unified":
|
190 |
+
# ------------------------------
|
191 |
+
# Unified loop: repeatedly call actor -> executor -> check tool_result -> maybe end
|
192 |
+
# ------------------------------
|
193 |
+
while True:
|
194 |
+
# Call the actor with current messages
|
195 |
+
response = actor(messages=messages)
|
196 |
+
|
197 |
+
# Let the executor process that response, yielding any intermediate messages
|
198 |
+
for message, tool_result_content in executor(response, messages):
|
199 |
+
yield message
|
200 |
+
|
201 |
+
# If executor didn't produce further content, we're done
|
202 |
+
if not tool_result_content:
|
203 |
+
return messages
|
204 |
+
|
205 |
+
# If there is more tool content, treat that as user input
|
206 |
+
messages.append({
|
207 |
+
"content": tool_result_content,
|
208 |
+
"role": "user"
|
209 |
+
})
|
210 |
+
|
211 |
+
elif loop_mode == "planner + actor":
|
212 |
+
# ------------------------------------------------------
|
213 |
+
# Planner + actor loop:
|
214 |
+
# 1) planner => get next_action
|
215 |
+
# 2) If no next_action -> end
|
216 |
+
# 3) Otherwise actor => executor
|
217 |
+
# 4) repeat
|
218 |
+
# ------------------------------------------------------
|
219 |
+
while True:
|
220 |
+
# Step 1: Planner (VLM) response
|
221 |
+
vlm_response = planner(messages=messages)
|
222 |
+
|
223 |
+
# Step 2: Extract the "Next Action" from the planner output
|
224 |
+
next_action = json.loads(vlm_response).get("Next Action")
|
225 |
+
|
226 |
+
# Yield the next_action string, in case the UI or logs want to show it
|
227 |
+
yield next_action
|
228 |
+
|
229 |
+
# Step 3: Check if there are no further actions
|
230 |
+
if not next_action or next_action in ("None", ""):
|
231 |
+
final_sc, final_sc_path = get_screenshot(selected_screen=selected_screen)
|
232 |
+
final_image_b64 = encode_image(str(final_sc_path))
|
233 |
+
|
234 |
+
output_callback(
|
235 |
+
(
|
236 |
+
f"No more actions from {colorful_text_vlm}. End of task. Final State:\n"
|
237 |
+
f'<img src="data:image/png;base64,{final_image_b64}">'
|
238 |
+
),
|
239 |
+
sender="bot"
|
240 |
+
)
|
241 |
+
yield None
|
242 |
+
break
|
243 |
+
|
244 |
+
# Step 4: Output an action message
|
245 |
+
output_callback(
|
246 |
+
f"{colorful_text_vlm} sending action to {colorful_text_showui}:\n{next_action}",
|
247 |
+
sender="bot"
|
248 |
+
)
|
249 |
+
|
250 |
+
# Step 5: Actor response
|
251 |
+
actor_response = actor(messages=next_action)
|
252 |
+
yield actor_response
|
253 |
+
|
254 |
+
# Step 6: Execute the actor response
|
255 |
+
for message, tool_result_content in executor(actor_response, messages):
|
256 |
+
time.sleep(0.5) # optional small delay
|
257 |
+
yield message
|
258 |
+
|
259 |
+
# Step 7: Update conversation with embedding history of plan and actions
|
260 |
+
messages.append({
|
261 |
+
"role": "user",
|
262 |
+
"content": [
|
263 |
+
"History plan:" + str(json.loads(vlm_response)),
|
264 |
+
"History actions:" + str(actor_response["content"])
|
265 |
+
]
|
266 |
+
})
|
267 |
+
|
268 |
+
logger.info(
|
269 |
+
f"End of loop {showui_loop_count + 1}. "
|
270 |
+
f"Messages: {str(messages)[:100000]}. "
|
271 |
+
f"Total cost: $USD{planner.total_cost:.5f}"
|
272 |
+
)
|
273 |
+
|
274 |
+
|
275 |
+
# Increment loop counter
|
276 |
+
showui_loop_count += 1
|
computer_use_demo/remote_inference.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import asynccontextmanager
|
2 |
+
from fastapi import FastAPI, HTTPException
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
from pydantic import BaseModel, field_validator
|
5 |
+
from typing import Optional, List, Union, Dict, Any
|
6 |
+
import torch
|
7 |
+
from transformers import (
|
8 |
+
Qwen2_5_VLForConditionalGeneration,
|
9 |
+
Qwen2VLForConditionalGeneration,
|
10 |
+
AutoProcessor,
|
11 |
+
BitsAndBytesConfig
|
12 |
+
)
|
13 |
+
from qwen_vl_utils import process_vision_info
|
14 |
+
import uvicorn
|
15 |
+
import json
|
16 |
+
from datetime import datetime
|
17 |
+
import logging
|
18 |
+
import time
|
19 |
+
import psutil
|
20 |
+
import GPUtil
|
21 |
+
import base64
|
22 |
+
from PIL import Image
|
23 |
+
import io
|
24 |
+
import os
|
25 |
+
import threading
|
26 |
+
|
27 |
+
# Set environment variables to disable compilation cache and avoid CUDA kernel issues
|
28 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
29 |
+
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0" # Compatible with A5000
|
30 |
+
|
31 |
+
# Model configuration
|
32 |
+
MODELS = {
|
33 |
+
"Qwen2.5-VL-7B-Instruct": {
|
34 |
+
"path": "Qwen/Qwen2.5-VL-7B-Instruct",
|
35 |
+
"model_class": Qwen2_5_VLForConditionalGeneration,
|
36 |
+
},
|
37 |
+
"Qwen2-VL-7B-Instruct": {
|
38 |
+
"path": "Qwen/Qwen2-VL-7B-Instruct",
|
39 |
+
"model_class": Qwen2VLForConditionalGeneration,
|
40 |
+
},
|
41 |
+
"Qwen2-VL-2B-Instruct": {
|
42 |
+
"path": "Qwen/Qwen2-VL-2B-Instruct",
|
43 |
+
"model_class": Qwen2VLForConditionalGeneration,
|
44 |
+
}
|
45 |
+
}
|
46 |
+
|
47 |
+
# Configure logging
|
48 |
+
logging.basicConfig(
|
49 |
+
level=logging.INFO,
|
50 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
51 |
+
)
|
52 |
+
logger = logging.getLogger(__name__)
|
53 |
+
|
54 |
+
# Global variables
|
55 |
+
models = {}
|
56 |
+
processors = {}
|
57 |
+
model_locks = {} # Thread locks for model loading
|
58 |
+
last_used = {} # Record last use time of models
|
59 |
+
|
60 |
+
# Set default CUDA device
|
61 |
+
if torch.cuda.is_available():
|
62 |
+
# Get GPU information and select the device with maximum memory
|
63 |
+
gpus = GPUtil.getGPUs()
|
64 |
+
if gpus:
|
65 |
+
max_memory_gpu = max(gpus, key=lambda g: g.memoryTotal)
|
66 |
+
selected_device = max_memory_gpu.id
|
67 |
+
torch.cuda.set_device(selected_device)
|
68 |
+
device = torch.device(f"cuda:{selected_device}")
|
69 |
+
logger.info(f"Selected GPU {selected_device} ({max_memory_gpu.name}) with {max_memory_gpu.memoryTotal}MB memory")
|
70 |
+
else:
|
71 |
+
device = torch.device("cuda:0")
|
72 |
+
else:
|
73 |
+
device = torch.device("cpu")
|
74 |
+
logger.info(f"Using device: {device}")
|
75 |
+
|
76 |
+
class ImageURL(BaseModel):
|
77 |
+
url: str
|
78 |
+
|
79 |
+
class MessageContent(BaseModel):
|
80 |
+
type: str
|
81 |
+
text: Optional[str] = None
|
82 |
+
image_url: Optional[Dict[str, str]] = None
|
83 |
+
|
84 |
+
@field_validator('type')
|
85 |
+
@classmethod
|
86 |
+
def validate_type(cls, v: str) -> str:
|
87 |
+
if v not in ['text', 'image_url']:
|
88 |
+
raise ValueError(f"Invalid content type: {v}")
|
89 |
+
return v
|
90 |
+
|
91 |
+
class ChatMessage(BaseModel):
|
92 |
+
role: str
|
93 |
+
content: Union[str, List[MessageContent]]
|
94 |
+
|
95 |
+
@field_validator('role')
|
96 |
+
@classmethod
|
97 |
+
def validate_role(cls, v: str) -> str:
|
98 |
+
if v not in ['system', 'user', 'assistant']:
|
99 |
+
raise ValueError(f"Invalid role: {v}")
|
100 |
+
return v
|
101 |
+
|
102 |
+
@field_validator('content')
|
103 |
+
@classmethod
|
104 |
+
def validate_content(cls, v: Union[str, List[Any]]) -> Union[str, List[MessageContent]]:
|
105 |
+
if isinstance(v, str):
|
106 |
+
return v
|
107 |
+
if isinstance(v, list):
|
108 |
+
return [MessageContent(**item) if isinstance(item, dict) else item for item in v]
|
109 |
+
raise ValueError("Content must be either a string or a list of content items")
|
110 |
+
|
111 |
+
class ChatCompletionRequest(BaseModel):
|
112 |
+
model: str
|
113 |
+
messages: List[ChatMessage]
|
114 |
+
temperature: Optional[float] = 0.7
|
115 |
+
top_p: Optional[float] = 0.95
|
116 |
+
max_tokens: Optional[int] = 2048
|
117 |
+
stream: Optional[bool] = False
|
118 |
+
response_format: Optional[Dict[str, str]] = None
|
119 |
+
|
120 |
+
class ChatCompletionResponse(BaseModel):
|
121 |
+
id: str
|
122 |
+
object: str
|
123 |
+
created: int
|
124 |
+
model: str
|
125 |
+
choices: List[Dict[str, Any]]
|
126 |
+
usage: Dict[str, int]
|
127 |
+
|
128 |
+
class ModelCard(BaseModel):
|
129 |
+
id: str
|
130 |
+
created: int
|
131 |
+
owned_by: str
|
132 |
+
permission: List[Dict[str, Any]] = []
|
133 |
+
root: Optional[str] = None
|
134 |
+
parent: Optional[str] = None
|
135 |
+
capabilities: Optional[Dict[str, bool]] = None
|
136 |
+
context_window: Optional[int] = None
|
137 |
+
max_tokens: Optional[int] = None
|
138 |
+
|
139 |
+
class ModelList(BaseModel):
|
140 |
+
object: str = "list"
|
141 |
+
data: List[ModelCard]
|
142 |
+
|
143 |
+
def process_base64_image(base64_string: str) -> Image.Image:
|
144 |
+
"""Process base64 image data and return PIL Image"""
|
145 |
+
try:
|
146 |
+
# Remove data URL prefix if present
|
147 |
+
if 'base64,' in base64_string:
|
148 |
+
base64_string = base64_string.split('base64,')[1]
|
149 |
+
|
150 |
+
image_data = base64.b64decode(base64_string)
|
151 |
+
image = Image.open(io.BytesIO(image_data))
|
152 |
+
|
153 |
+
# Convert to RGB if necessary
|
154 |
+
if image.mode not in ('RGB', 'L'):
|
155 |
+
image = image.convert('RGB')
|
156 |
+
|
157 |
+
return image
|
158 |
+
except Exception as e:
|
159 |
+
logger.error(f"Error processing base64 image: {str(e)}")
|
160 |
+
raise ValueError(f"Invalid base64 image data: {str(e)}")
|
161 |
+
|
162 |
+
def log_system_info():
|
163 |
+
"""Log system resource information"""
|
164 |
+
try:
|
165 |
+
cpu_percent = psutil.cpu_percent(interval=1)
|
166 |
+
memory = psutil.virtual_memory()
|
167 |
+
gpu_info = []
|
168 |
+
if torch.cuda.is_available():
|
169 |
+
for gpu in GPUtil.getGPUs():
|
170 |
+
gpu_info.append({
|
171 |
+
'id': gpu.id,
|
172 |
+
'name': gpu.name,
|
173 |
+
'load': f"{gpu.load*100}%",
|
174 |
+
'memory_used': f"{gpu.memoryUsed}MB/{gpu.memoryTotal}MB",
|
175 |
+
'temperature': f"{gpu.temperature}°C"
|
176 |
+
})
|
177 |
+
logger.info(f"System Info - CPU: {cpu_percent}%, RAM: {memory.percent}%, "
|
178 |
+
f"Available RAM: {memory.available/1024/1024/1024:.1f}GB")
|
179 |
+
if gpu_info:
|
180 |
+
logger.info(f"GPU Info: {gpu_info}")
|
181 |
+
except Exception as e:
|
182 |
+
logger.warning(f"Failed to log system info: {str(e)}")
|
183 |
+
|
184 |
+
def get_or_initialize_model(model_name: str):
|
185 |
+
"""Get or initialize a model if not already loaded"""
|
186 |
+
global models, processors, model_locks, last_used
|
187 |
+
|
188 |
+
if model_name not in MODELS:
|
189 |
+
available_models = list(MODELS.keys())
|
190 |
+
raise ValueError(f"Unsupported model: {model_name}\nAvailable models: {available_models}")
|
191 |
+
|
192 |
+
# Initialize lock for the model (if not already done)
|
193 |
+
if model_name not in model_locks:
|
194 |
+
model_locks[model_name] = threading.Lock()
|
195 |
+
|
196 |
+
with model_locks[model_name]:
|
197 |
+
if model_name not in models or model_name not in processors:
|
198 |
+
try:
|
199 |
+
start_time = time.time()
|
200 |
+
logger.info(f"Starting {model_name} initialization...")
|
201 |
+
log_system_info()
|
202 |
+
|
203 |
+
model_config = MODELS[model_name]
|
204 |
+
|
205 |
+
# Configure 8-bit quantization
|
206 |
+
quantization_config = BitsAndBytesConfig(
|
207 |
+
load_in_8bit=True,
|
208 |
+
bnb_4bit_compute_dtype=torch.float16,
|
209 |
+
bnb_4bit_use_double_quant=False,
|
210 |
+
bnb_4bit_quant_type="nf4",
|
211 |
+
)
|
212 |
+
|
213 |
+
logger.info(f"Loading {model_name} with 8-bit quantization...")
|
214 |
+
model = model_config["model_class"].from_pretrained(
|
215 |
+
model_config["path"],
|
216 |
+
quantization_config=quantization_config,
|
217 |
+
device_map={"": device.index if device.type == "cuda" else "cpu"},
|
218 |
+
local_files_only=False
|
219 |
+
).eval()
|
220 |
+
|
221 |
+
processor = AutoProcessor.from_pretrained(
|
222 |
+
model_config["path"],
|
223 |
+
local_files_only=False
|
224 |
+
)
|
225 |
+
|
226 |
+
models[model_name] = model
|
227 |
+
processors[model_name] = processor
|
228 |
+
|
229 |
+
end_time = time.time()
|
230 |
+
logger.info(f"Model {model_name} initialized in {end_time - start_time:.2f} seconds")
|
231 |
+
log_system_info()
|
232 |
+
|
233 |
+
except Exception as e:
|
234 |
+
logger.error(f"Model initialization error for {model_name}: {str(e)}", exc_info=True)
|
235 |
+
raise RuntimeError(f"Failed to initialize model {model_name}: {str(e)}")
|
236 |
+
|
237 |
+
# Update last use time
|
238 |
+
last_used[model_name] = time.time()
|
239 |
+
|
240 |
+
return models[model_name], processors[model_name]
|
241 |
+
|
242 |
+
@asynccontextmanager
|
243 |
+
async def lifespan(app: FastAPI):
|
244 |
+
logger.info("Starting application initialization...")
|
245 |
+
try:
|
246 |
+
yield
|
247 |
+
finally:
|
248 |
+
logger.info("Shutting down application...")
|
249 |
+
global models, processors
|
250 |
+
for model_name, model in models.items():
|
251 |
+
try:
|
252 |
+
del model
|
253 |
+
logger.info(f"Model {model_name} unloaded")
|
254 |
+
except Exception as e:
|
255 |
+
logger.error(f"Error during cleanup of {model_name}: {str(e)}")
|
256 |
+
|
257 |
+
if torch.cuda.is_available():
|
258 |
+
torch.cuda.empty_cache()
|
259 |
+
logger.info("CUDA cache cleared")
|
260 |
+
|
261 |
+
models = {}
|
262 |
+
processors = {}
|
263 |
+
logger.info("Shutdown complete")
|
264 |
+
|
265 |
+
app = FastAPI(
|
266 |
+
title="Qwen2.5-VL API",
|
267 |
+
description="OpenAI-compatible API for Qwen2.5-VL vision-language model",
|
268 |
+
version="1.0.0",
|
269 |
+
lifespan=lifespan
|
270 |
+
)
|
271 |
+
|
272 |
+
app.add_middleware(
|
273 |
+
CORSMiddleware,
|
274 |
+
allow_origins=["*"],
|
275 |
+
allow_credentials=True,
|
276 |
+
allow_methods=["*"],
|
277 |
+
allow_headers=["*"],
|
278 |
+
)
|
279 |
+
|
280 |
+
@app.get("/v1/models", response_model=ModelList)
|
281 |
+
async def list_models():
|
282 |
+
"""List available models"""
|
283 |
+
model_cards = []
|
284 |
+
for model_name in MODELS.keys():
|
285 |
+
model_cards.append(
|
286 |
+
ModelCard(
|
287 |
+
id=model_name,
|
288 |
+
created=1709251200,
|
289 |
+
owned_by="Qwen",
|
290 |
+
permission=[{
|
291 |
+
"id": f"modelperm-{model_name}",
|
292 |
+
"created": 1709251200,
|
293 |
+
"allow_create_engine": False,
|
294 |
+
"allow_sampling": True,
|
295 |
+
"allow_logprobs": True,
|
296 |
+
"allow_search_indices": False,
|
297 |
+
"allow_view": True,
|
298 |
+
"allow_fine_tuning": False,
|
299 |
+
"organization": "*",
|
300 |
+
"group": None,
|
301 |
+
"is_blocking": False
|
302 |
+
}],
|
303 |
+
capabilities={
|
304 |
+
"vision": True,
|
305 |
+
"chat": True,
|
306 |
+
"embeddings": False,
|
307 |
+
"text_completion": True
|
308 |
+
},
|
309 |
+
context_window=4096,
|
310 |
+
max_tokens=2048
|
311 |
+
)
|
312 |
+
)
|
313 |
+
return ModelList(data=model_cards)
|
314 |
+
|
315 |
+
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
316 |
+
async def chat_completions(request: ChatCompletionRequest):
|
317 |
+
"""Handle chat completion requests with vision support"""
|
318 |
+
try:
|
319 |
+
# Get or initialize requested model
|
320 |
+
model, processor = get_or_initialize_model(request.model)
|
321 |
+
|
322 |
+
request_start_time = time.time()
|
323 |
+
logger.info(f"Received chat completion request for model: {request.model}")
|
324 |
+
logger.info(f"Request content: {request.model_dump_json()}")
|
325 |
+
|
326 |
+
messages = []
|
327 |
+
for msg in request.messages:
|
328 |
+
if isinstance(msg.content, str):
|
329 |
+
messages.append({"role": msg.role, "content": msg.content})
|
330 |
+
else:
|
331 |
+
processed_content = []
|
332 |
+
for content_item in msg.content:
|
333 |
+
if content_item.type == "text":
|
334 |
+
processed_content.append({
|
335 |
+
"type": "text",
|
336 |
+
"text": content_item.text
|
337 |
+
})
|
338 |
+
elif content_item.type == "image_url":
|
339 |
+
if "url" in content_item.image_url:
|
340 |
+
if content_item.image_url["url"].startswith("data:image"):
|
341 |
+
processed_content.append({
|
342 |
+
"type": "image",
|
343 |
+
"image": process_base64_image(content_item.image_url["url"])
|
344 |
+
})
|
345 |
+
messages.append({"role": msg.role, "content": processed_content})
|
346 |
+
|
347 |
+
text = processor.apply_chat_template(
|
348 |
+
messages,
|
349 |
+
tokenize=False,
|
350 |
+
add_generation_prompt=True
|
351 |
+
)
|
352 |
+
|
353 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
354 |
+
|
355 |
+
# Ensure input data is on the correct device
|
356 |
+
inputs = processor(
|
357 |
+
text=[text],
|
358 |
+
images=image_inputs,
|
359 |
+
videos=video_inputs,
|
360 |
+
padding=True,
|
361 |
+
return_tensors="pt"
|
362 |
+
)
|
363 |
+
|
364 |
+
# Move all tensors to specified device
|
365 |
+
input_tensors = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
|
366 |
+
|
367 |
+
with torch.inference_mode():
|
368 |
+
generated_ids = model.generate(
|
369 |
+
**input_tensors,
|
370 |
+
max_new_tokens=request.max_tokens,
|
371 |
+
temperature=request.temperature,
|
372 |
+
top_p=request.top_p,
|
373 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
374 |
+
eos_token_id=processor.tokenizer.eos_token_id
|
375 |
+
)
|
376 |
+
|
377 |
+
# Get input length and trim generated IDs
|
378 |
+
input_length = input_tensors['input_ids'].shape[1]
|
379 |
+
generated_ids_trimmed = generated_ids[:, input_length:]
|
380 |
+
|
381 |
+
response = processor.batch_decode(
|
382 |
+
generated_ids_trimmed,
|
383 |
+
skip_special_tokens=True,
|
384 |
+
clean_up_tokenization_spaces=False
|
385 |
+
)[0]
|
386 |
+
|
387 |
+
if request.response_format and request.response_format.get("type") == "json_object":
|
388 |
+
try:
|
389 |
+
if response.startswith('```'):
|
390 |
+
response = '\n'.join(response.split('\n')[1:-1])
|
391 |
+
if response.startswith('json'):
|
392 |
+
response = response[4:].lstrip()
|
393 |
+
content = json.loads(response)
|
394 |
+
response = json.dumps(content)
|
395 |
+
except json.JSONDecodeError as e:
|
396 |
+
logger.error(f"JSON parsing error: {str(e)}")
|
397 |
+
raise HTTPException(status_code=400, detail=f"Invalid JSON response: {str(e)}")
|
398 |
+
|
399 |
+
total_time = time.time() - request_start_time
|
400 |
+
logger.info(f"Request completed in {total_time:.2f} seconds")
|
401 |
+
|
402 |
+
return ChatCompletionResponse(
|
403 |
+
id=f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
|
404 |
+
object="chat.completion",
|
405 |
+
created=int(datetime.now().timestamp()),
|
406 |
+
model=request.model,
|
407 |
+
choices=[{
|
408 |
+
"index": 0,
|
409 |
+
"message": {
|
410 |
+
"role": "assistant",
|
411 |
+
"content": response
|
412 |
+
},
|
413 |
+
"finish_reason": "stop"
|
414 |
+
}],
|
415 |
+
usage={
|
416 |
+
"prompt_tokens": input_length,
|
417 |
+
"completion_tokens": len(generated_ids_trimmed[0]),
|
418 |
+
"total_tokens": input_length + len(generated_ids_trimmed[0])
|
419 |
+
}
|
420 |
+
)
|
421 |
+
except Exception as e:
|
422 |
+
logger.error(f"Request error: {str(e)}", exc_info=True)
|
423 |
+
if isinstance(e, HTTPException):
|
424 |
+
raise
|
425 |
+
raise HTTPException(status_code=500, detail=str(e))
|
426 |
+
|
427 |
+
@app.get("/health")
|
428 |
+
async def health_check():
|
429 |
+
"""Health check endpoint"""
|
430 |
+
log_system_info()
|
431 |
+
return {
|
432 |
+
"status": "healthy",
|
433 |
+
"loaded_models": list(models.keys()),
|
434 |
+
"device": str(device),
|
435 |
+
"cuda_available": torch.cuda.is_available(),
|
436 |
+
"cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
437 |
+
"timestamp": datetime.now().isoformat()
|
438 |
+
}
|
439 |
+
|
440 |
+
@app.get("/model_status")
|
441 |
+
async def model_status():
|
442 |
+
"""Get the status of all models"""
|
443 |
+
status = {}
|
444 |
+
for model_name in MODELS:
|
445 |
+
status[model_name] = {
|
446 |
+
"loaded": model_name in models,
|
447 |
+
"last_used": last_used.get(model_name, None),
|
448 |
+
"available": model_name in MODELS
|
449 |
+
}
|
450 |
+
return status
|
451 |
+
|
452 |
+
if __name__ == "__main__":
|
453 |
+
uvicorn.run(app, host="0.0.0.0", port=9192)
|
computer_use_demo/tools/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import CLIResult, ToolResult
|
2 |
+
from .bash import BashTool
|
3 |
+
from .collection import ToolCollection
|
4 |
+
from .computer import ComputerTool
|
5 |
+
from .edit import EditTool
|
6 |
+
from .screen_capture import get_screenshot
|
7 |
+
|
8 |
+
__ALL__ = [
|
9 |
+
BashTool,
|
10 |
+
CLIResult,
|
11 |
+
ComputerTool,
|
12 |
+
EditTool,
|
13 |
+
ToolCollection,
|
14 |
+
ToolResult,
|
15 |
+
get_screenshot,
|
16 |
+
]
|
computer_use_demo/tools/base.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABCMeta, abstractmethod
|
2 |
+
from dataclasses import dataclass, fields, replace
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
from anthropic.types.beta import BetaToolUnionParam
|
6 |
+
|
7 |
+
|
8 |
+
class BaseAnthropicTool(metaclass=ABCMeta):
|
9 |
+
"""Abstract base class for Anthropic-defined tools."""
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
def __call__(self, **kwargs) -> Any:
|
13 |
+
"""Executes the tool with the given arguments."""
|
14 |
+
...
|
15 |
+
|
16 |
+
@abstractmethod
|
17 |
+
def to_params(
|
18 |
+
self,
|
19 |
+
) -> BetaToolUnionParam:
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass(kw_only=True, frozen=True)
|
24 |
+
class ToolResult:
|
25 |
+
"""Represents the result of a tool execution."""
|
26 |
+
|
27 |
+
output: str | None = None
|
28 |
+
error: str | None = None
|
29 |
+
base64_image: str | None = None
|
30 |
+
system: str | None = None
|
31 |
+
|
32 |
+
def __bool__(self):
|
33 |
+
return any(getattr(self, field.name) for field in fields(self))
|
34 |
+
|
35 |
+
def __add__(self, other: "ToolResult"):
|
36 |
+
def combine_fields(
|
37 |
+
field: str | None, other_field: str | None, concatenate: bool = True
|
38 |
+
):
|
39 |
+
if field and other_field:
|
40 |
+
if concatenate:
|
41 |
+
return field + other_field
|
42 |
+
raise ValueError("Cannot combine tool results")
|
43 |
+
return field or other_field
|
44 |
+
|
45 |
+
return ToolResult(
|
46 |
+
output=combine_fields(self.output, other.output),
|
47 |
+
error=combine_fields(self.error, other.error),
|
48 |
+
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
49 |
+
system=combine_fields(self.system, other.system),
|
50 |
+
)
|
51 |
+
|
52 |
+
def replace(self, **kwargs):
|
53 |
+
"""Returns a new ToolResult with the given fields replaced."""
|
54 |
+
return replace(self, **kwargs)
|
55 |
+
|
56 |
+
|
57 |
+
class CLIResult(ToolResult):
|
58 |
+
"""A ToolResult that can be rendered as a CLI output."""
|
59 |
+
|
60 |
+
|
61 |
+
class ToolFailure(ToolResult):
|
62 |
+
"""A ToolResult that represents a failure."""
|
63 |
+
|
64 |
+
|
65 |
+
class ToolError(Exception):
|
66 |
+
"""Raised when a tool encounters an error."""
|
67 |
+
|
68 |
+
def __init__(self, message):
|
69 |
+
self.message = message
|
computer_use_demo/tools/bash.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
from typing import ClassVar, Literal
|
4 |
+
|
5 |
+
from anthropic.types.beta import BetaToolBash20241022Param
|
6 |
+
|
7 |
+
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
8 |
+
|
9 |
+
|
10 |
+
class _BashSession:
|
11 |
+
"""A session of a bash shell."""
|
12 |
+
|
13 |
+
_started: bool
|
14 |
+
_process: asyncio.subprocess.Process
|
15 |
+
|
16 |
+
command: str = "/bin/bash"
|
17 |
+
_output_delay: float = 0.2 # seconds
|
18 |
+
_timeout: float = 120.0 # seconds
|
19 |
+
_sentinel: str = "<<exit>>"
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._started = False
|
23 |
+
self._timed_out = False
|
24 |
+
|
25 |
+
async def start(self):
|
26 |
+
if self._started:
|
27 |
+
return
|
28 |
+
|
29 |
+
self._process = await asyncio.create_subprocess_shell(
|
30 |
+
self.command,
|
31 |
+
shell=False,
|
32 |
+
stdin=asyncio.subprocess.PIPE,
|
33 |
+
stdout=asyncio.subprocess.PIPE,
|
34 |
+
stderr=asyncio.subprocess.PIPE,
|
35 |
+
)
|
36 |
+
|
37 |
+
self._started = True
|
38 |
+
|
39 |
+
def stop(self):
|
40 |
+
"""Terminate the bash shell."""
|
41 |
+
if not self._started:
|
42 |
+
raise ToolError("Session has not started.")
|
43 |
+
if self._process.returncode is not None:
|
44 |
+
return
|
45 |
+
self._process.terminate()
|
46 |
+
|
47 |
+
async def run(self, command: str):
|
48 |
+
"""Execute a command in the bash shell."""
|
49 |
+
if not self._started:
|
50 |
+
raise ToolError("Session has not started.")
|
51 |
+
if self._process.returncode is not None:
|
52 |
+
return ToolResult(
|
53 |
+
system="tool must be restarted",
|
54 |
+
error=f"bash has exited with returncode {self._process.returncode}",
|
55 |
+
)
|
56 |
+
if self._timed_out:
|
57 |
+
raise ToolError(
|
58 |
+
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
59 |
+
)
|
60 |
+
|
61 |
+
# we know these are not None because we created the process with PIPEs
|
62 |
+
assert self._process.stdin
|
63 |
+
assert self._process.stdout
|
64 |
+
assert self._process.stderr
|
65 |
+
|
66 |
+
# send command to the process
|
67 |
+
self._process.stdin.write(
|
68 |
+
command.encode() + f"; echo '{self._sentinel}'\n".encode()
|
69 |
+
)
|
70 |
+
await self._process.stdin.drain()
|
71 |
+
|
72 |
+
# read output from the process, until the sentinel is found
|
73 |
+
output = ""
|
74 |
+
try:
|
75 |
+
async with asyncio.timeout(self._timeout):
|
76 |
+
while True:
|
77 |
+
await asyncio.sleep(self._output_delay)
|
78 |
+
data = await self._process.stdout.readline()
|
79 |
+
if not data:
|
80 |
+
break
|
81 |
+
line = data.decode()
|
82 |
+
output += line
|
83 |
+
if self._sentinel in line:
|
84 |
+
output = output.replace(self._sentinel, "")
|
85 |
+
break
|
86 |
+
except asyncio.TimeoutError:
|
87 |
+
self._timed_out = True
|
88 |
+
raise ToolError(
|
89 |
+
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
90 |
+
) from None
|
91 |
+
|
92 |
+
error = await self._process.stderr.read()
|
93 |
+
error = error.decode()
|
94 |
+
|
95 |
+
return CLIResult(output=output.strip(), error=error.strip())
|
96 |
+
|
97 |
+
|
98 |
+
class BashTool(BaseAnthropicTool):
|
99 |
+
"""
|
100 |
+
A tool that allows the agent to run bash commands.
|
101 |
+
The tool parameters are defined by Anthropic and are not editable.
|
102 |
+
"""
|
103 |
+
|
104 |
+
_session: _BashSession | None
|
105 |
+
name: ClassVar[Literal["bash"]] = "bash"
|
106 |
+
api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"
|
107 |
+
|
108 |
+
def __init__(self):
|
109 |
+
self._session = None
|
110 |
+
super().__init__()
|
111 |
+
|
112 |
+
async def __call__(
|
113 |
+
self, command: str | None = None, restart: bool = False, **kwargs
|
114 |
+
):
|
115 |
+
if restart:
|
116 |
+
if self._session:
|
117 |
+
self._session.stop()
|
118 |
+
self._session = _BashSession()
|
119 |
+
await self._session.start()
|
120 |
+
|
121 |
+
return ToolResult(system="tool has been restarted.")
|
122 |
+
|
123 |
+
if self._session is None:
|
124 |
+
self._session = _BashSession()
|
125 |
+
await self._session.start()
|
126 |
+
|
127 |
+
if command is not None:
|
128 |
+
return await self._session.run(command)
|
129 |
+
|
130 |
+
raise ToolError("no command provided.")
|
131 |
+
|
132 |
+
def to_params(self) -> BetaToolBash20241022Param:
|
133 |
+
return {
|
134 |
+
"type": self.api_type,
|
135 |
+
"name": self.name,
|
136 |
+
}
|
computer_use_demo/tools/collection.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Collection classes for managing multiple tools."""
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
from anthropic.types.beta import BetaToolUnionParam
|
6 |
+
|
7 |
+
from .base import (
|
8 |
+
BaseAnthropicTool,
|
9 |
+
ToolError,
|
10 |
+
ToolFailure,
|
11 |
+
ToolResult,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
class ToolCollection:
|
16 |
+
"""A collection of anthropic-defined tools."""
|
17 |
+
|
18 |
+
def __init__(self, *tools: BaseAnthropicTool):
|
19 |
+
self.tools = tools
|
20 |
+
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
|
21 |
+
|
22 |
+
def to_params(
|
23 |
+
self,
|
24 |
+
) -> list[BetaToolUnionParam]:
|
25 |
+
return [tool.to_params() for tool in self.tools]
|
26 |
+
|
27 |
+
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
28 |
+
tool = self.tool_map.get(name)
|
29 |
+
if not tool:
|
30 |
+
return ToolFailure(error=f"Tool {name} is invalid")
|
31 |
+
try:
|
32 |
+
return await tool(**tool_input)
|
33 |
+
except ToolError as e:
|
34 |
+
return ToolFailure(error=e.message)
|
35 |
+
|
36 |
+
def sync_call(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
37 |
+
print(f"sync_call: {name} {tool_input}")
|
38 |
+
tool = self.tool_map.get(name)
|
39 |
+
if not tool:
|
40 |
+
return ToolFailure(error=f"Tool {name} is invalid")
|
41 |
+
return tool.sync_call(**tool_input)
|
computer_use_demo/tools/colorful_text.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Define some colorful stuffs for better visualization in the chat.
|
3 |
+
"""
|
4 |
+
|
5 |
+
# Define the RGB colors for each letter
|
6 |
+
colors = {
|
7 |
+
'S': 'rgb(106, 158, 210)',
|
8 |
+
'h': 'rgb(111, 163, 82)',
|
9 |
+
'o': 'rgb(209, 100, 94)',
|
10 |
+
'w': 'rgb(238, 171, 106)',
|
11 |
+
'U': 'rgb(0, 0, 0)',
|
12 |
+
'I': 'rgb(0, 0, 0)',
|
13 |
+
}
|
14 |
+
|
15 |
+
# Construct the colorful "ShowUI" word
|
16 |
+
colorful_text_showui = "**"+''.join(
|
17 |
+
f'<span style="color:{colors.get(letter, "black")}">{letter}</span>'
|
18 |
+
for letter in "ShowUI"
|
19 |
+
)+"**"
|
20 |
+
|
21 |
+
|
22 |
+
colorful_text_vlm = "**VLMPlanner**"
|
23 |
+
|
24 |
+
colorful_text_user = "**User**"
|
25 |
+
|
26 |
+
# print(f"colorful_text_showui: {colorful_text_showui}")
|
27 |
+
# **<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span><span style="color:rgb(0, 0, 0)">U</span><span style="color:rgb(0, 0, 0)">I</span>**
|
computer_use_demo/tools/computer.py
ADDED
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import platform
|
3 |
+
import pyautogui
|
4 |
+
import asyncio
|
5 |
+
import base64
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
if platform.system() == "Darwin":
|
9 |
+
import Quartz # uncomment this line if you are on macOS
|
10 |
+
from enum import StrEnum
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Literal, TypedDict
|
13 |
+
from uuid import uuid4
|
14 |
+
from screeninfo import get_monitors
|
15 |
+
|
16 |
+
from PIL import ImageGrab, Image
|
17 |
+
from functools import partial
|
18 |
+
|
19 |
+
from anthropic.types.beta import BetaToolComputerUse20241022Param
|
20 |
+
|
21 |
+
from .base import BaseAnthropicTool, ToolError, ToolResult
|
22 |
+
from .run import run
|
23 |
+
|
24 |
+
OUTPUT_DIR = "./tmp/outputs"
|
25 |
+
|
26 |
+
TYPING_DELAY_MS = 12
|
27 |
+
TYPING_GROUP_SIZE = 50
|
28 |
+
|
29 |
+
Action = Literal[
|
30 |
+
"key",
|
31 |
+
"type",
|
32 |
+
"mouse_move",
|
33 |
+
"left_click",
|
34 |
+
"left_click_drag",
|
35 |
+
"right_click",
|
36 |
+
"middle_click",
|
37 |
+
"double_click",
|
38 |
+
"screenshot",
|
39 |
+
"cursor_position",
|
40 |
+
]
|
41 |
+
|
42 |
+
|
43 |
+
class Resolution(TypedDict):
|
44 |
+
width: int
|
45 |
+
height: int
|
46 |
+
|
47 |
+
|
48 |
+
MAX_SCALING_TARGETS: dict[str, Resolution] = {
|
49 |
+
"XGA": Resolution(width=1024, height=768), # 4:3
|
50 |
+
"WXGA": Resolution(width=1280, height=800), # 16:10
|
51 |
+
"FWXGA": Resolution(width=1366, height=768), # ~16:9
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
class ScalingSource(StrEnum):
|
56 |
+
COMPUTER = "computer"
|
57 |
+
API = "api"
|
58 |
+
|
59 |
+
|
60 |
+
class ComputerToolOptions(TypedDict):
|
61 |
+
display_height_px: int
|
62 |
+
display_width_px: int
|
63 |
+
display_number: int | None
|
64 |
+
|
65 |
+
|
66 |
+
def chunks(s: str, chunk_size: int) -> list[str]:
|
67 |
+
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]
|
68 |
+
|
69 |
+
|
70 |
+
def get_screen_details():
|
71 |
+
screens = get_monitors()
|
72 |
+
screen_details = []
|
73 |
+
|
74 |
+
# Sort screens by x position to arrange from left to right
|
75 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
76 |
+
|
77 |
+
# Loop through sorted screens and assign positions
|
78 |
+
primary_index = 0
|
79 |
+
for i, screen in enumerate(sorted_screens):
|
80 |
+
if i == 0:
|
81 |
+
layout = "Left"
|
82 |
+
elif i == len(sorted_screens) - 1:
|
83 |
+
layout = "Right"
|
84 |
+
else:
|
85 |
+
layout = "Center"
|
86 |
+
|
87 |
+
if screen.is_primary:
|
88 |
+
position = "Primary"
|
89 |
+
primary_index = i
|
90 |
+
else:
|
91 |
+
position = "Secondary"
|
92 |
+
screen_info = f"Screen {i + 1}: {screen.width}x{screen.height}, {layout}, {position}"
|
93 |
+
screen_details.append(screen_info)
|
94 |
+
|
95 |
+
return screen_details, primary_index
|
96 |
+
|
97 |
+
|
98 |
+
class ComputerTool(BaseAnthropicTool):
|
99 |
+
"""
|
100 |
+
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
|
101 |
+
Adapted for Windows using 'pyautogui'.
|
102 |
+
"""
|
103 |
+
|
104 |
+
name: Literal["computer"] = "computer"
|
105 |
+
api_type: Literal["computer_20241022"] = "computer_20241022"
|
106 |
+
width: int
|
107 |
+
height: int
|
108 |
+
display_num: int | None
|
109 |
+
|
110 |
+
_screenshot_delay = 2.0
|
111 |
+
_scaling_enabled = True
|
112 |
+
|
113 |
+
@property
|
114 |
+
def options(self) -> ComputerToolOptions:
|
115 |
+
width, height = self.scale_coordinates(
|
116 |
+
ScalingSource.COMPUTER, self.width, self.height
|
117 |
+
)
|
118 |
+
return {
|
119 |
+
"display_width_px": width,
|
120 |
+
"display_height_px": height,
|
121 |
+
"display_number": self.display_num,
|
122 |
+
}
|
123 |
+
|
124 |
+
def to_params(self) -> BetaToolComputerUse20241022Param:
|
125 |
+
return {"name": self.name, "type": self.api_type, **self.options}
|
126 |
+
|
127 |
+
def __init__(self, selected_screen: int = 0, is_scaling: bool = True):
|
128 |
+
super().__init__()
|
129 |
+
|
130 |
+
# Get screen width and height using Windows command
|
131 |
+
self.display_num = None
|
132 |
+
self.offset_x = 0
|
133 |
+
self.offset_y = 0
|
134 |
+
self.selected_screen = selected_screen
|
135 |
+
self.is_scaling = is_scaling
|
136 |
+
self.width, self.height = self.get_screen_size()
|
137 |
+
|
138 |
+
# Path to cliclick
|
139 |
+
self.cliclick = "cliclick"
|
140 |
+
self.key_conversion = {"Page_Down": "pagedown",
|
141 |
+
"Page_Up": "pageup",
|
142 |
+
"Super_L": "win",
|
143 |
+
"Escape": "esc"}
|
144 |
+
|
145 |
+
self.action_conversion = {"left click": "click",
|
146 |
+
"right click": "right_click"}
|
147 |
+
|
148 |
+
system = platform.system() # Detect platform
|
149 |
+
if system == "Windows":
|
150 |
+
screens = get_monitors()
|
151 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
152 |
+
if self.selected_screen < 0 or self.selected_screen >= len(screens):
|
153 |
+
raise IndexError("Invalid screen index.")
|
154 |
+
screen = sorted_screens[self.selected_screen]
|
155 |
+
bbox = (screen.x, screen.y, screen.x + screen.width, screen.y + screen.height)
|
156 |
+
|
157 |
+
elif system == "Darwin": # macOS
|
158 |
+
max_displays = 32 # Maximum number of displays to handle
|
159 |
+
active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
|
160 |
+
screens = []
|
161 |
+
for display_id in active_displays:
|
162 |
+
bounds = Quartz.CGDisplayBounds(display_id)
|
163 |
+
screens.append({
|
164 |
+
'id': display_id, 'x': int(bounds.origin.x), 'y': int(bounds.origin.y),
|
165 |
+
'width': int(bounds.size.width), 'height': int(bounds.size.height),
|
166 |
+
'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
|
167 |
+
})
|
168 |
+
sorted_screens = sorted(screens, key=lambda s: s['x'])
|
169 |
+
if self.selected_screen < 0 or self.selected_screen >= len(screens):
|
170 |
+
raise IndexError("Invalid screen index.")
|
171 |
+
screen = sorted_screens[self.selected_screen]
|
172 |
+
bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
|
173 |
+
else: # Linux or other OS
|
174 |
+
cmd = "xrandr | grep ' primary' | awk '{print $4}'"
|
175 |
+
try:
|
176 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
177 |
+
resolution = output.strip().split()[0]
|
178 |
+
width, height = map(int, resolution.split('x'))
|
179 |
+
bbox = (0, 0, width, height) # Assuming single primary screen for simplicity
|
180 |
+
except subprocess.CalledProcessError:
|
181 |
+
raise RuntimeError("Failed to get screen resolution on Linux.")
|
182 |
+
|
183 |
+
self.offset_x = screen['x'] if system == "Darwin" else screen.x
|
184 |
+
self.offset_y = screen['y'] if system == "Darwin" else screen.y
|
185 |
+
self.bbox = bbox
|
186 |
+
|
187 |
+
|
188 |
+
async def __call__(
|
189 |
+
self,
|
190 |
+
*,
|
191 |
+
action: Action,
|
192 |
+
text: str | None = None,
|
193 |
+
coordinate: tuple[int, int] | None = None,
|
194 |
+
**kwargs,
|
195 |
+
):
|
196 |
+
print(f"action: {action}, text: {text}, coordinate: {coordinate}")
|
197 |
+
action = self.action_conversion.get(action, action)
|
198 |
+
|
199 |
+
if action in ("mouse_move", "left_click_drag"):
|
200 |
+
if coordinate is None:
|
201 |
+
raise ToolError(f"coordinate is required for {action}")
|
202 |
+
if text is not None:
|
203 |
+
raise ToolError(f"text is not accepted for {action}")
|
204 |
+
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
|
205 |
+
raise ToolError(f"{coordinate} must be a tuple of length 2")
|
206 |
+
# if not all(isinstance(i, int) and i >= 0 for i in coordinate):
|
207 |
+
if not all(isinstance(i, int) for i in coordinate):
|
208 |
+
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
|
209 |
+
|
210 |
+
if self.is_scaling:
|
211 |
+
x, y = self.scale_coordinates(
|
212 |
+
ScalingSource.API, coordinate[0], coordinate[1]
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
x, y = coordinate
|
216 |
+
|
217 |
+
# print(f"scaled_coordinates: {x}, {y}")
|
218 |
+
# print(f"offset: {self.offset_x}, {self.offset_y}")
|
219 |
+
|
220 |
+
x += self.offset_x
|
221 |
+
y += self.offset_y
|
222 |
+
|
223 |
+
print(f"mouse move to {x}, {y}")
|
224 |
+
|
225 |
+
if action == "mouse_move":
|
226 |
+
pyautogui.moveTo(x, y)
|
227 |
+
return ToolResult(output=f"Moved mouse to ({x}, {y})")
|
228 |
+
elif action == "left_click_drag":
|
229 |
+
current_x, current_y = pyautogui.position()
|
230 |
+
pyautogui.dragTo(x, y, duration=0.5) # Adjust duration as needed
|
231 |
+
return ToolResult(output=f"Dragged mouse from ({current_x}, {current_y}) to ({x}, {y})")
|
232 |
+
|
233 |
+
if action in ("key", "type"):
|
234 |
+
if text is None:
|
235 |
+
raise ToolError(f"text is required for {action}")
|
236 |
+
if coordinate is not None:
|
237 |
+
raise ToolError(f"coordinate is not accepted for {action}")
|
238 |
+
if not isinstance(text, str):
|
239 |
+
raise ToolError(output=f"{text} must be a string")
|
240 |
+
|
241 |
+
if action == "key":
|
242 |
+
# Handle key combinations
|
243 |
+
keys = text.split('+')
|
244 |
+
for key in keys:
|
245 |
+
key = self.key_conversion.get(key.strip(), key.strip())
|
246 |
+
key = key.lower()
|
247 |
+
pyautogui.keyDown(key) # Press down each key
|
248 |
+
for key in reversed(keys):
|
249 |
+
key = self.key_conversion.get(key.strip(), key.strip())
|
250 |
+
key = key.lower()
|
251 |
+
pyautogui.keyUp(key) # Release each key in reverse order
|
252 |
+
return ToolResult(output=f"Pressed keys: {text}")
|
253 |
+
|
254 |
+
elif action == "type":
|
255 |
+
pyautogui.typewrite(text, interval=TYPING_DELAY_MS / 1000) # Convert ms to seconds
|
256 |
+
screenshot_base64 = (await self.screenshot()).base64_image
|
257 |
+
return ToolResult(output=text, base64_image=screenshot_base64)
|
258 |
+
|
259 |
+
if action in (
|
260 |
+
"left_click",
|
261 |
+
"right_click",
|
262 |
+
"double_click",
|
263 |
+
"middle_click",
|
264 |
+
"screenshot",
|
265 |
+
"cursor_position",
|
266 |
+
"left_press",
|
267 |
+
):
|
268 |
+
if text is not None:
|
269 |
+
raise ToolError(f"text is not accepted for {action}")
|
270 |
+
if coordinate is not None:
|
271 |
+
raise ToolError(f"coordinate is not accepted for {action}")
|
272 |
+
|
273 |
+
if action == "screenshot":
|
274 |
+
return await self.screenshot()
|
275 |
+
elif action == "cursor_position":
|
276 |
+
x, y = pyautogui.position()
|
277 |
+
x, y = self.scale_coordinates(ScalingSource.COMPUTER, x, y)
|
278 |
+
return ToolResult(output=f"X={x},Y={y}")
|
279 |
+
else:
|
280 |
+
if action == "left_click":
|
281 |
+
pyautogui.click()
|
282 |
+
elif action == "right_click":
|
283 |
+
pyautogui.rightClick()
|
284 |
+
elif action == "middle_click":
|
285 |
+
pyautogui.middleClick()
|
286 |
+
elif action == "double_click":
|
287 |
+
pyautogui.doubleClick()
|
288 |
+
elif action == "left_press":
|
289 |
+
pyautogui.mouseDown()
|
290 |
+
time.sleep(1)
|
291 |
+
pyautogui.mouseUp()
|
292 |
+
return ToolResult(output=f"Performed {action}")
|
293 |
+
|
294 |
+
raise ToolError(f"Invalid action: {action}")
|
295 |
+
|
296 |
+
|
297 |
+
def sync_call(
|
298 |
+
self,
|
299 |
+
*,
|
300 |
+
action: Action,
|
301 |
+
text: str | None = None,
|
302 |
+
coordinate: tuple[int, int] | None = None,
|
303 |
+
**kwargs,
|
304 |
+
):
|
305 |
+
print(f"action: {action}, text: {text}, coordinate: {coordinate}")
|
306 |
+
action = self.action_conversion.get(action, action)
|
307 |
+
|
308 |
+
if action in ("mouse_move", "left_click_drag"):
|
309 |
+
if coordinate is None:
|
310 |
+
raise ToolError(f"coordinate is required for {action}")
|
311 |
+
if text is not None:
|
312 |
+
raise ToolError(f"text is not accepted for {action}")
|
313 |
+
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
|
314 |
+
raise ToolError(f"{coordinate} must be a tuple of length 2")
|
315 |
+
# if not all(isinstance(i, int) and i >= 0 for i in coordinate):
|
316 |
+
if not all(isinstance(i, int) for i in coordinate):
|
317 |
+
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
|
318 |
+
|
319 |
+
if self.is_scaling:
|
320 |
+
x, y = self.scale_coordinates(
|
321 |
+
ScalingSource.API, coordinate[0], coordinate[1]
|
322 |
+
)
|
323 |
+
else:
|
324 |
+
x, y = coordinate
|
325 |
+
|
326 |
+
# print(f"scaled_coordinates: {x}, {y}")
|
327 |
+
# print(f"offset: {self.offset_x}, {self.offset_y}")
|
328 |
+
x += self.offset_x
|
329 |
+
y += self.offset_y
|
330 |
+
|
331 |
+
print(f"mouse move to {x}, {y}")
|
332 |
+
|
333 |
+
if action == "mouse_move":
|
334 |
+
pyautogui.moveTo(x, y)
|
335 |
+
return ToolResult(output=f"Moved mouse to ({x}, {y})")
|
336 |
+
elif action == "left_click_drag":
|
337 |
+
current_x, current_y = pyautogui.position()
|
338 |
+
pyautogui.dragTo(x, y, duration=0.5) # Adjust duration as needed
|
339 |
+
return ToolResult(output=f"Dragged mouse from ({current_x}, {current_y}) to ({x}, {y})")
|
340 |
+
|
341 |
+
if action in ("key", "type"):
|
342 |
+
if text is None:
|
343 |
+
raise ToolError(f"text is required for {action}")
|
344 |
+
if coordinate is not None:
|
345 |
+
raise ToolError(f"coordinate is not accepted for {action}")
|
346 |
+
if not isinstance(text, str):
|
347 |
+
raise ToolError(output=f"{text} must be a string")
|
348 |
+
|
349 |
+
if action == "key":
|
350 |
+
# Handle key combinations
|
351 |
+
keys = text.split('+')
|
352 |
+
for key in keys:
|
353 |
+
key = self.key_conversion.get(key.strip(), key.strip())
|
354 |
+
key = key.lower()
|
355 |
+
pyautogui.keyDown(key) # Press down each key
|
356 |
+
for key in reversed(keys):
|
357 |
+
key = self.key_conversion.get(key.strip(), key.strip())
|
358 |
+
key = key.lower()
|
359 |
+
pyautogui.keyUp(key) # Release each key in reverse order
|
360 |
+
return ToolResult(output=f"Pressed keys: {text}")
|
361 |
+
|
362 |
+
elif action == "type":
|
363 |
+
pyautogui.typewrite(text, interval=TYPING_DELAY_MS / 1000) # Convert ms to seconds
|
364 |
+
return ToolResult(output=text)
|
365 |
+
|
366 |
+
if action in (
|
367 |
+
"left_click",
|
368 |
+
"right_click",
|
369 |
+
"double_click",
|
370 |
+
"middle_click",
|
371 |
+
"screenshot",
|
372 |
+
"cursor_position",
|
373 |
+
"left_press",
|
374 |
+
):
|
375 |
+
if text is not None:
|
376 |
+
raise ToolError(f"text is not accepted for {action}")
|
377 |
+
if coordinate is not None:
|
378 |
+
raise ToolError(f"coordinate is not accepted for {action}")
|
379 |
+
elif action == "cursor_position":
|
380 |
+
x, y = pyautogui.position()
|
381 |
+
x, y = self.scale_coordinates(ScalingSource.COMPUTER, x, y)
|
382 |
+
return ToolResult(output=f"X={x},Y={y}")
|
383 |
+
else:
|
384 |
+
if action == "left_click":
|
385 |
+
pyautogui.click()
|
386 |
+
elif action == "right_click":
|
387 |
+
pyautogui.rightClick()
|
388 |
+
elif action == "middle_click":
|
389 |
+
pyautogui.middleClick()
|
390 |
+
elif action == "double_click":
|
391 |
+
pyautogui.doubleClick()
|
392 |
+
elif action == "left_press":
|
393 |
+
pyautogui.mouseDown()
|
394 |
+
time.sleep(1)
|
395 |
+
pyautogui.mouseUp()
|
396 |
+
return ToolResult(output=f"Performed {action}")
|
397 |
+
|
398 |
+
raise ToolError(f"Invalid action: {action}")
|
399 |
+
|
400 |
+
async def screenshot(self):
|
401 |
+
|
402 |
+
import time
|
403 |
+
time.sleep(1)
|
404 |
+
|
405 |
+
"""Take a screenshot of the current screen and return a ToolResult with the base64 encoded image."""
|
406 |
+
output_dir = Path(OUTPUT_DIR)
|
407 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
408 |
+
path = output_dir / f"screenshot_{uuid4().hex}.png"
|
409 |
+
|
410 |
+
ImageGrab.grab = partial(ImageGrab.grab, all_screens=True)
|
411 |
+
|
412 |
+
# Detect platform
|
413 |
+
system = platform.system()
|
414 |
+
|
415 |
+
if system == "Windows":
|
416 |
+
# Windows: Use screeninfo to get monitor details
|
417 |
+
screens = get_monitors()
|
418 |
+
|
419 |
+
# Sort screens by x position to arrange from left to right
|
420 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
421 |
+
|
422 |
+
if self.selected_screen < 0 or self.selected_screen >= len(screens):
|
423 |
+
raise IndexError("Invalid screen index.")
|
424 |
+
|
425 |
+
screen = sorted_screens[self.selected_screen]
|
426 |
+
bbox = (screen.x, screen.y, screen.x + screen.width, screen.y + screen.height)
|
427 |
+
|
428 |
+
elif system == "Darwin": # macOS
|
429 |
+
# macOS: Use Quartz to get monitor details
|
430 |
+
max_displays = 32 # Maximum number of displays to handle
|
431 |
+
active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
|
432 |
+
|
433 |
+
# Get the display bounds (resolution) for each active display
|
434 |
+
screens = []
|
435 |
+
for display_id in active_displays:
|
436 |
+
bounds = Quartz.CGDisplayBounds(display_id)
|
437 |
+
screens.append({
|
438 |
+
'id': display_id,
|
439 |
+
'x': int(bounds.origin.x),
|
440 |
+
'y': int(bounds.origin.y),
|
441 |
+
'width': int(bounds.size.width),
|
442 |
+
'height': int(bounds.size.height),
|
443 |
+
'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
|
444 |
+
})
|
445 |
+
|
446 |
+
# Sort screens by x position to arrange from left to right
|
447 |
+
sorted_screens = sorted(screens, key=lambda s: s['x'])
|
448 |
+
|
449 |
+
if self.selected_screen < 0 or self.selected_screen >= len(screens):
|
450 |
+
raise IndexError("Invalid screen index.")
|
451 |
+
|
452 |
+
screen = sorted_screens[self.selected_screen]
|
453 |
+
bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
|
454 |
+
|
455 |
+
else: # Linux or other OS
|
456 |
+
cmd = "xrandr | grep ' primary' | awk '{print $4}'"
|
457 |
+
try:
|
458 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
459 |
+
resolution = output.strip().split()[0]
|
460 |
+
width, height = map(int, resolution.split('x'))
|
461 |
+
bbox = (0, 0, width, height) # Assuming single primary screen for simplicity
|
462 |
+
except subprocess.CalledProcessError:
|
463 |
+
raise RuntimeError("Failed to get screen resolution on Linux.")
|
464 |
+
|
465 |
+
# Take screenshot using the bounding box
|
466 |
+
screenshot = ImageGrab.grab(bbox=bbox)
|
467 |
+
|
468 |
+
# Set offsets (for potential future use)
|
469 |
+
self.offset_x = screen['x'] if system == "Darwin" else screen.x
|
470 |
+
self.offset_y = screen['y'] if system == "Darwin" else screen.y
|
471 |
+
|
472 |
+
print(f"target_dimension {self.target_dimension}")
|
473 |
+
|
474 |
+
if not hasattr(self, 'target_dimension'):
|
475 |
+
screenshot = self.padding_image(screenshot)
|
476 |
+
self.target_dimension = MAX_SCALING_TARGETS["WXGA"]
|
477 |
+
|
478 |
+
# Resize if target_dimensions are specified
|
479 |
+
print(f"offset is {self.offset_x}, {self.offset_y}")
|
480 |
+
print(f"target_dimension is {self.target_dimension}")
|
481 |
+
screenshot = screenshot.resize((self.target_dimension["width"], self.target_dimension["height"]))
|
482 |
+
|
483 |
+
# Save the screenshot
|
484 |
+
screenshot.save(str(path))
|
485 |
+
|
486 |
+
if path.exists():
|
487 |
+
# Return a ToolResult instance instead of a dictionary
|
488 |
+
return ToolResult(base64_image=base64.b64encode(path.read_bytes()).decode())
|
489 |
+
|
490 |
+
raise ToolError(f"Failed to take screenshot: {path} does not exist.")
|
491 |
+
|
492 |
+
def padding_image(self, screenshot):
|
493 |
+
"""Pad the screenshot to 16:10 aspect ratio, when the aspect ratio is not 16:10."""
|
494 |
+
_, height = screenshot.size
|
495 |
+
new_width = height * 16 // 10
|
496 |
+
|
497 |
+
padding_image = Image.new("RGB", (new_width, height), (255, 255, 255))
|
498 |
+
# padding to top left
|
499 |
+
padding_image.paste(screenshot, (0, 0))
|
500 |
+
return padding_image
|
501 |
+
|
502 |
+
async def shell(self, command: str, take_screenshot=True) -> ToolResult:
|
503 |
+
"""Run a shell command and return the output, error, and optionally a screenshot."""
|
504 |
+
_, stdout, stderr = await run(command)
|
505 |
+
base64_image = None
|
506 |
+
|
507 |
+
if take_screenshot:
|
508 |
+
# delay to let things settle before taking a screenshot
|
509 |
+
await asyncio.sleep(self._screenshot_delay)
|
510 |
+
base64_image = (await self.screenshot()).base64_image
|
511 |
+
|
512 |
+
return ToolResult(output=stdout, error=stderr, base64_image=base64_image)
|
513 |
+
|
514 |
+
def scale_coordinates(self, source: ScalingSource, x: int, y: int):
|
515 |
+
"""Scale coordinates to a target maximum resolution."""
|
516 |
+
if not self._scaling_enabled:
|
517 |
+
return x, y
|
518 |
+
ratio = self.width / self.height
|
519 |
+
target_dimension = None
|
520 |
+
|
521 |
+
for target_name, dimension in MAX_SCALING_TARGETS.items():
|
522 |
+
# allow some error in the aspect ratio - not ratios are exactly 16:9
|
523 |
+
if abs(dimension["width"] / dimension["height"] - ratio) < 0.02:
|
524 |
+
if dimension["width"] < self.width:
|
525 |
+
target_dimension = dimension
|
526 |
+
self.target_dimension = target_dimension
|
527 |
+
# print(f"target_dimension: {target_dimension}")
|
528 |
+
break
|
529 |
+
|
530 |
+
if target_dimension is None:
|
531 |
+
# TODO: currently we force the target to be WXGA (16:10), when it cannot find a match
|
532 |
+
target_dimension = MAX_SCALING_TARGETS["WXGA"]
|
533 |
+
self.target_dimension = MAX_SCALING_TARGETS["WXGA"]
|
534 |
+
|
535 |
+
# should be less than 1
|
536 |
+
x_scaling_factor = target_dimension["width"] / self.width
|
537 |
+
y_scaling_factor = target_dimension["height"] / self.height
|
538 |
+
if source == ScalingSource.API:
|
539 |
+
if x > self.width or y > self.height:
|
540 |
+
raise ToolError(f"Coordinates {x}, {y} are out of bounds")
|
541 |
+
# scale up
|
542 |
+
return round(x / x_scaling_factor), round(y / y_scaling_factor)
|
543 |
+
# scale down
|
544 |
+
return round(x * x_scaling_factor), round(y * y_scaling_factor)
|
545 |
+
|
546 |
+
def get_screen_size(self):
|
547 |
+
if platform.system() == "Windows":
|
548 |
+
# Use screeninfo to get primary monitor on Windows
|
549 |
+
screens = get_monitors()
|
550 |
+
|
551 |
+
# Sort screens by x position to arrange from left to right
|
552 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
553 |
+
|
554 |
+
if self.selected_screen is None:
|
555 |
+
primary_monitor = next((m for m in get_monitors() if m.is_primary), None)
|
556 |
+
return primary_monitor.width, primary_monitor.height
|
557 |
+
elif self.selected_screen < 0 or self.selected_screen >= len(screens):
|
558 |
+
raise IndexError("Invalid screen index.")
|
559 |
+
else:
|
560 |
+
screen = sorted_screens[self.selected_screen]
|
561 |
+
return screen.width, screen.height
|
562 |
+
|
563 |
+
elif platform.system() == "Darwin":
|
564 |
+
# macOS part using Quartz to get screen information
|
565 |
+
max_displays = 32 # Maximum number of displays to handle
|
566 |
+
active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
|
567 |
+
|
568 |
+
# Get the display bounds (resolution) for each active display
|
569 |
+
screens = []
|
570 |
+
for display_id in active_displays:
|
571 |
+
bounds = Quartz.CGDisplayBounds(display_id)
|
572 |
+
screens.append({
|
573 |
+
'id': display_id,
|
574 |
+
'x': int(bounds.origin.x),
|
575 |
+
'y': int(bounds.origin.y),
|
576 |
+
'width': int(bounds.size.width),
|
577 |
+
'height': int(bounds.size.height),
|
578 |
+
'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
|
579 |
+
})
|
580 |
+
|
581 |
+
# Sort screens by x position to arrange from left to right
|
582 |
+
sorted_screens = sorted(screens, key=lambda s: s['x'])
|
583 |
+
|
584 |
+
if self.selected_screen is None:
|
585 |
+
# Find the primary monitor
|
586 |
+
primary_monitor = next((screen for screen in screens if screen['is_primary']), None)
|
587 |
+
if primary_monitor:
|
588 |
+
return primary_monitor['width'], primary_monitor['height']
|
589 |
+
else:
|
590 |
+
raise RuntimeError("No primary monitor found.")
|
591 |
+
elif self.selected_screen < 0 or self.selected_screen >= len(screens):
|
592 |
+
raise IndexError("Invalid screen index.")
|
593 |
+
else:
|
594 |
+
# Return the resolution of the selected screen
|
595 |
+
screen = sorted_screens[self.selected_screen]
|
596 |
+
return screen['width'], screen['height']
|
597 |
+
|
598 |
+
else: # Linux or other OS
|
599 |
+
cmd = "xrandr | grep ' primary' | awk '{print $4}'"
|
600 |
+
try:
|
601 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
602 |
+
resolution = output.strip().split()[0]
|
603 |
+
width, height = map(int, resolution.split('x'))
|
604 |
+
return width, height
|
605 |
+
except subprocess.CalledProcessError:
|
606 |
+
raise RuntimeError("Failed to get screen resolution on Linux.")
|
607 |
+
|
608 |
+
def get_mouse_position(self):
|
609 |
+
# TODO: enhance this func
|
610 |
+
from AppKit import NSEvent
|
611 |
+
from Quartz import CGEventSourceCreate, kCGEventSourceStateCombinedSessionState
|
612 |
+
|
613 |
+
loc = NSEvent.mouseLocation()
|
614 |
+
# Adjust for different coordinate system
|
615 |
+
return int(loc.x), int(self.height - loc.y)
|
616 |
+
|
617 |
+
def map_keys(self, text: str):
|
618 |
+
"""Map text to cliclick key codes if necessary."""
|
619 |
+
# For simplicity, return text as is
|
620 |
+
# Implement mapping if special keys are needed
|
621 |
+
return text
|
computer_use_demo/tools/edit.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Literal, get_args
|
4 |
+
|
5 |
+
from anthropic.types.beta import BetaToolTextEditor20241022Param
|
6 |
+
|
7 |
+
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
8 |
+
from .run import maybe_truncate, run
|
9 |
+
|
10 |
+
Command = Literal[
|
11 |
+
"view",
|
12 |
+
"create",
|
13 |
+
"str_replace",
|
14 |
+
"insert",
|
15 |
+
"undo_edit",
|
16 |
+
]
|
17 |
+
SNIPPET_LINES: int = 4
|
18 |
+
|
19 |
+
|
20 |
+
class EditTool(BaseAnthropicTool):
|
21 |
+
"""
|
22 |
+
An filesystem editor tool that allows the agent to view, create, and edit files.
|
23 |
+
The tool parameters are defined by Anthropic and are not editable.
|
24 |
+
"""
|
25 |
+
|
26 |
+
api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
|
27 |
+
name: Literal["str_replace_editor"] = "str_replace_editor"
|
28 |
+
|
29 |
+
_file_history: dict[Path, list[str]]
|
30 |
+
|
31 |
+
def __init__(self):
|
32 |
+
self._file_history = defaultdict(list)
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
def to_params(self) -> BetaToolTextEditor20241022Param:
|
36 |
+
return {
|
37 |
+
"name": self.name,
|
38 |
+
"type": self.api_type,
|
39 |
+
}
|
40 |
+
|
41 |
+
async def __call__(
|
42 |
+
self,
|
43 |
+
*,
|
44 |
+
command: Command,
|
45 |
+
path: str,
|
46 |
+
file_text: str | None = None,
|
47 |
+
view_range: list[int] | None = None,
|
48 |
+
old_str: str | None = None,
|
49 |
+
new_str: str | None = None,
|
50 |
+
insert_line: int | None = None,
|
51 |
+
**kwargs,
|
52 |
+
):
|
53 |
+
_path = Path(path)
|
54 |
+
self.validate_path(command, _path)
|
55 |
+
if command == "view":
|
56 |
+
return await self.view(_path, view_range)
|
57 |
+
elif command == "create":
|
58 |
+
if not file_text:
|
59 |
+
raise ToolError("Parameter `file_text` is required for command: create")
|
60 |
+
self.write_file(_path, file_text)
|
61 |
+
self._file_history[_path].append(file_text)
|
62 |
+
return ToolResult(output=f"File created successfully at: {_path}")
|
63 |
+
elif command == "str_replace":
|
64 |
+
if not old_str:
|
65 |
+
raise ToolError(
|
66 |
+
"Parameter `old_str` is required for command: str_replace"
|
67 |
+
)
|
68 |
+
return self.str_replace(_path, old_str, new_str)
|
69 |
+
elif command == "insert":
|
70 |
+
if insert_line is None:
|
71 |
+
raise ToolError(
|
72 |
+
"Parameter `insert_line` is required for command: insert"
|
73 |
+
)
|
74 |
+
if not new_str:
|
75 |
+
raise ToolError("Parameter `new_str` is required for command: insert")
|
76 |
+
return self.insert(_path, insert_line, new_str)
|
77 |
+
elif command == "undo_edit":
|
78 |
+
return self.undo_edit(_path)
|
79 |
+
raise ToolError(
|
80 |
+
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
|
81 |
+
)
|
82 |
+
|
83 |
+
def validate_path(self, command: str, path: Path):
|
84 |
+
"""
|
85 |
+
Check that the path/command combination is valid.
|
86 |
+
"""
|
87 |
+
# Check if its an absolute path
|
88 |
+
if not path.is_absolute():
|
89 |
+
suggested_path = Path("") / path
|
90 |
+
raise ToolError(
|
91 |
+
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
92 |
+
)
|
93 |
+
# Check if path exists
|
94 |
+
if not path.exists() and command != "create":
|
95 |
+
raise ToolError(
|
96 |
+
f"The path {path} does not exist. Please provide a valid path."
|
97 |
+
)
|
98 |
+
if path.exists() and command == "create":
|
99 |
+
raise ToolError(
|
100 |
+
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
101 |
+
)
|
102 |
+
# Check if the path points to a directory
|
103 |
+
if path.is_dir():
|
104 |
+
if command != "view":
|
105 |
+
raise ToolError(
|
106 |
+
f"The path {path} is a directory and only the `view` command can be used on directories"
|
107 |
+
)
|
108 |
+
|
109 |
+
async def view(self, path: Path, view_range: list[int] | None = None):
|
110 |
+
"""Implement the view command"""
|
111 |
+
if path.is_dir():
|
112 |
+
if view_range:
|
113 |
+
raise ToolError(
|
114 |
+
"The `view_range` parameter is not allowed when `path` points to a directory."
|
115 |
+
)
|
116 |
+
|
117 |
+
_, stdout, stderr = await run(
|
118 |
+
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
|
119 |
+
)
|
120 |
+
if not stderr:
|
121 |
+
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
|
122 |
+
return CLIResult(output=stdout, error=stderr)
|
123 |
+
|
124 |
+
file_content = self.read_file(path)
|
125 |
+
init_line = 1
|
126 |
+
if view_range:
|
127 |
+
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
|
128 |
+
raise ToolError(
|
129 |
+
"Invalid `view_range`. It should be a list of two integers."
|
130 |
+
)
|
131 |
+
file_lines = file_content.split("\n")
|
132 |
+
n_lines_file = len(file_lines)
|
133 |
+
init_line, final_line = view_range
|
134 |
+
if init_line < 1 or init_line > n_lines_file:
|
135 |
+
raise ToolError(
|
136 |
+
f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
137 |
+
)
|
138 |
+
if final_line > n_lines_file:
|
139 |
+
raise ToolError(
|
140 |
+
f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
141 |
+
)
|
142 |
+
if final_line != -1 and final_line < init_line:
|
143 |
+
raise ToolError(
|
144 |
+
f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
145 |
+
)
|
146 |
+
|
147 |
+
if final_line == -1:
|
148 |
+
file_content = "\n".join(file_lines[init_line - 1 :])
|
149 |
+
else:
|
150 |
+
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
151 |
+
|
152 |
+
return CLIResult(
|
153 |
+
output=self._make_output(file_content, str(path), init_line=init_line)
|
154 |
+
)
|
155 |
+
|
156 |
+
def str_replace(self, path: Path, old_str: str, new_str: str | None):
|
157 |
+
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
|
158 |
+
# Read the file content
|
159 |
+
file_content = self.read_file(path).expandtabs()
|
160 |
+
old_str = old_str.expandtabs()
|
161 |
+
new_str = new_str.expandtabs() if new_str is not None else ""
|
162 |
+
|
163 |
+
# Check if old_str is unique in the file
|
164 |
+
occurrences = file_content.count(old_str)
|
165 |
+
if occurrences == 0:
|
166 |
+
raise ToolError(
|
167 |
+
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
168 |
+
)
|
169 |
+
elif occurrences > 1:
|
170 |
+
file_content_lines = file_content.split("\n")
|
171 |
+
lines = [
|
172 |
+
idx + 1
|
173 |
+
for idx, line in enumerate(file_content_lines)
|
174 |
+
if old_str in line
|
175 |
+
]
|
176 |
+
raise ToolError(
|
177 |
+
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
178 |
+
)
|
179 |
+
|
180 |
+
# Replace old_str with new_str
|
181 |
+
new_file_content = file_content.replace(old_str, new_str)
|
182 |
+
|
183 |
+
# Write the new content to the file
|
184 |
+
self.write_file(path, new_file_content)
|
185 |
+
|
186 |
+
# Save the content to history
|
187 |
+
self._file_history[path].append(file_content)
|
188 |
+
|
189 |
+
# Create a snippet of the edited section
|
190 |
+
replacement_line = file_content.split(old_str)[0].count("\n")
|
191 |
+
start_line = max(0, replacement_line - SNIPPET_LINES)
|
192 |
+
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
193 |
+
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
|
194 |
+
|
195 |
+
# Prepare the success message
|
196 |
+
success_msg = f"The file {path} has been edited. "
|
197 |
+
success_msg += self._make_output(
|
198 |
+
snippet, f"a snippet of {path}", start_line + 1
|
199 |
+
)
|
200 |
+
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
201 |
+
|
202 |
+
return CLIResult(output=success_msg)
|
203 |
+
|
204 |
+
def insert(self, path: Path, insert_line: int, new_str: str):
|
205 |
+
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
|
206 |
+
file_text = self.read_file(path).expandtabs()
|
207 |
+
new_str = new_str.expandtabs()
|
208 |
+
file_text_lines = file_text.split("\n")
|
209 |
+
n_lines_file = len(file_text_lines)
|
210 |
+
|
211 |
+
if insert_line < 0 or insert_line > n_lines_file:
|
212 |
+
raise ToolError(
|
213 |
+
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
214 |
+
)
|
215 |
+
|
216 |
+
new_str_lines = new_str.split("\n")
|
217 |
+
new_file_text_lines = (
|
218 |
+
file_text_lines[:insert_line]
|
219 |
+
+ new_str_lines
|
220 |
+
+ file_text_lines[insert_line:]
|
221 |
+
)
|
222 |
+
snippet_lines = (
|
223 |
+
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
224 |
+
+ new_str_lines
|
225 |
+
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
226 |
+
)
|
227 |
+
|
228 |
+
new_file_text = "\n".join(new_file_text_lines)
|
229 |
+
snippet = "\n".join(snippet_lines)
|
230 |
+
|
231 |
+
self.write_file(path, new_file_text)
|
232 |
+
self._file_history[path].append(file_text)
|
233 |
+
|
234 |
+
success_msg = f"The file {path} has been edited. "
|
235 |
+
success_msg += self._make_output(
|
236 |
+
snippet,
|
237 |
+
"a snippet of the edited file",
|
238 |
+
max(1, insert_line - SNIPPET_LINES + 1),
|
239 |
+
)
|
240 |
+
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
241 |
+
return CLIResult(output=success_msg)
|
242 |
+
|
243 |
+
def undo_edit(self, path: Path):
|
244 |
+
"""Implement the undo_edit command."""
|
245 |
+
if not self._file_history[path]:
|
246 |
+
raise ToolError(f"No edit history found for {path}.")
|
247 |
+
|
248 |
+
old_text = self._file_history[path].pop()
|
249 |
+
self.write_file(path, old_text)
|
250 |
+
|
251 |
+
return CLIResult(
|
252 |
+
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
|
253 |
+
)
|
254 |
+
|
255 |
+
def read_file(self, path: Path):
|
256 |
+
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
|
257 |
+
try:
|
258 |
+
return path.read_text()
|
259 |
+
except Exception as e:
|
260 |
+
raise ToolError(f"Ran into {e} while trying to read {path}") from None
|
261 |
+
|
262 |
+
def write_file(self, path: Path, file: str):
|
263 |
+
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
|
264 |
+
try:
|
265 |
+
path.write_text(file)
|
266 |
+
except Exception as e:
|
267 |
+
raise ToolError(f"Ran into {e} while trying to write to {path}") from None
|
268 |
+
|
269 |
+
def _make_output(
|
270 |
+
self,
|
271 |
+
file_content: str,
|
272 |
+
file_descriptor: str,
|
273 |
+
init_line: int = 1,
|
274 |
+
expand_tabs: bool = True,
|
275 |
+
):
|
276 |
+
"""Generate output for the CLI based on the content of a file."""
|
277 |
+
file_content = maybe_truncate(file_content)
|
278 |
+
if expand_tabs:
|
279 |
+
file_content = file_content.expandtabs()
|
280 |
+
file_content = "\n".join(
|
281 |
+
[
|
282 |
+
f"{i + init_line:6}\t{line}"
|
283 |
+
for i, line in enumerate(file_content.split("\n"))
|
284 |
+
]
|
285 |
+
)
|
286 |
+
return (
|
287 |
+
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
|
288 |
+
+ file_content
|
289 |
+
+ "\n"
|
290 |
+
)
|
computer_use_demo/tools/logger.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
|
4 |
+
def truncate_string(s, max_length=500):
|
5 |
+
"""Truncate long strings for concise printing."""
|
6 |
+
if isinstance(s, str) and len(s) > max_length:
|
7 |
+
return s[:max_length] + "..."
|
8 |
+
return s
|
9 |
+
|
10 |
+
# Configure logger
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
logger.setLevel(logging.INFO) # Choose your default level (INFO, DEBUG, etc.)
|
13 |
+
|
14 |
+
|
15 |
+
# Optionally add a console handler if you don't have one already
|
16 |
+
if not logger.handlers:
|
17 |
+
console_handler = logging.StreamHandler()
|
18 |
+
console_handler.setLevel(logging.INFO)
|
19 |
+
formatter = logging.Formatter("[%(levelname)s] %(name)s - %(message)s")
|
20 |
+
console_handler.setFormatter(formatter)
|
21 |
+
logger.addHandler(console_handler)
|
computer_use_demo/tools/run.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility to run shell commands asynchronously with a timeout."""
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
|
5 |
+
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
6 |
+
MAX_RESPONSE_LEN: int = 16000
|
7 |
+
|
8 |
+
|
9 |
+
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
10 |
+
"""Truncate content and append a notice if content exceeds the specified length."""
|
11 |
+
return (
|
12 |
+
content
|
13 |
+
if not truncate_after or len(content) <= truncate_after
|
14 |
+
else content[:truncate_after] + TRUNCATED_MESSAGE
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
async def run(
|
19 |
+
cmd: str,
|
20 |
+
timeout: float | None = 120.0, # seconds
|
21 |
+
truncate_after: int | None = MAX_RESPONSE_LEN,
|
22 |
+
):
|
23 |
+
"""Run a shell command asynchronously with a timeout."""
|
24 |
+
process = await asyncio.create_subprocess_shell(
|
25 |
+
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
26 |
+
)
|
27 |
+
|
28 |
+
try:
|
29 |
+
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
30 |
+
return (
|
31 |
+
process.returncode or 0,
|
32 |
+
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
33 |
+
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
34 |
+
)
|
35 |
+
except asyncio.TimeoutError as exc:
|
36 |
+
try:
|
37 |
+
process.kill()
|
38 |
+
except ProcessLookupError:
|
39 |
+
pass
|
40 |
+
raise TimeoutError(
|
41 |
+
f"Command '{cmd}' timed out after {timeout} seconds"
|
42 |
+
) from exc
|
computer_use_demo/tools/screen_capture.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import base64
|
3 |
+
from pathlib import Path
|
4 |
+
from PIL import ImageGrab
|
5 |
+
from uuid import uuid4
|
6 |
+
from screeninfo import get_monitors
|
7 |
+
import platform
|
8 |
+
if platform.system() == "Darwin":
|
9 |
+
import Quartz # uncomment this line if you are on macOS
|
10 |
+
|
11 |
+
from PIL import ImageGrab
|
12 |
+
from functools import partial
|
13 |
+
from .base import BaseAnthropicTool, ToolError, ToolResult
|
14 |
+
|
15 |
+
|
16 |
+
OUTPUT_DIR = "./tmp/outputs"
|
17 |
+
|
18 |
+
def get_screenshot(selected_screen: int = 0, resize: bool = True, target_width: int = 1920, target_height: int = 1080):
|
19 |
+
# print(f"get_screenshot selected_screen: {selected_screen}")
|
20 |
+
|
21 |
+
# Get screen width and height using Windows command
|
22 |
+
display_num = None
|
23 |
+
offset_x = 0
|
24 |
+
offset_y = 0
|
25 |
+
selected_screen = selected_screen
|
26 |
+
width, height = _get_screen_size()
|
27 |
+
|
28 |
+
"""Take a screenshot of the current screen and return a ToolResult with the base64 encoded image."""
|
29 |
+
output_dir = Path(OUTPUT_DIR)
|
30 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
31 |
+
path = output_dir / f"screenshot_{uuid4().hex}.png"
|
32 |
+
|
33 |
+
ImageGrab.grab = partial(ImageGrab.grab, all_screens=True)
|
34 |
+
|
35 |
+
# Detect platform
|
36 |
+
system = platform.system()
|
37 |
+
|
38 |
+
if system == "Windows":
|
39 |
+
# Windows: Use screeninfo to get monitor details
|
40 |
+
screens = get_monitors()
|
41 |
+
|
42 |
+
# Sort screens by x position to arrange from left to right
|
43 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
44 |
+
|
45 |
+
if selected_screen < 0 or selected_screen >= len(screens):
|
46 |
+
raise IndexError("Invalid screen index.")
|
47 |
+
|
48 |
+
screen = sorted_screens[selected_screen]
|
49 |
+
bbox = (screen.x, screen.y, screen.x + screen.width, screen.y + screen.height)
|
50 |
+
|
51 |
+
elif system == "Darwin": # macOS
|
52 |
+
# macOS: Use Quartz to get monitor details
|
53 |
+
max_displays = 32 # Maximum number of displays to handle
|
54 |
+
active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
|
55 |
+
|
56 |
+
# Get the display bounds (resolution) for each active display
|
57 |
+
screens = []
|
58 |
+
for display_id in active_displays:
|
59 |
+
bounds = Quartz.CGDisplayBounds(display_id)
|
60 |
+
screens.append({
|
61 |
+
'id': display_id,
|
62 |
+
'x': int(bounds.origin.x),
|
63 |
+
'y': int(bounds.origin.y),
|
64 |
+
'width': int(bounds.size.width),
|
65 |
+
'height': int(bounds.size.height),
|
66 |
+
'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
|
67 |
+
})
|
68 |
+
|
69 |
+
# Sort screens by x position to arrange from left to right
|
70 |
+
sorted_screens = sorted(screens, key=lambda s: s['x'])
|
71 |
+
# print(f"Darwin sorted_screens: {sorted_screens}")
|
72 |
+
|
73 |
+
if selected_screen < 0 or selected_screen >= len(screens):
|
74 |
+
raise IndexError("Invalid screen index.")
|
75 |
+
|
76 |
+
screen = sorted_screens[selected_screen]
|
77 |
+
|
78 |
+
bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
|
79 |
+
|
80 |
+
else: # Linux or other OS
|
81 |
+
cmd = "xrandr | grep ' primary' | awk '{print $4}'"
|
82 |
+
try:
|
83 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
84 |
+
resolution = output.strip().split()[0]
|
85 |
+
width, height = map(int, resolution.split('x'))
|
86 |
+
bbox = (0, 0, width, height) # Assuming single primary screen for simplicity
|
87 |
+
except subprocess.CalledProcessError:
|
88 |
+
raise RuntimeError("Failed to get screen resolution on Linux.")
|
89 |
+
|
90 |
+
# Take screenshot using the bounding box
|
91 |
+
screenshot = ImageGrab.grab(bbox=bbox)
|
92 |
+
|
93 |
+
# Set offsets (for potential future use)
|
94 |
+
offset_x = screen['x'] if system == "Darwin" else screen.x
|
95 |
+
offset_y = screen['y'] if system == "Darwin" else screen.y
|
96 |
+
|
97 |
+
# # Resize if
|
98 |
+
if resize:
|
99 |
+
screenshot = screenshot.resize((target_width, target_height))
|
100 |
+
|
101 |
+
# Save the screenshot
|
102 |
+
screenshot.save(str(path))
|
103 |
+
|
104 |
+
if path.exists():
|
105 |
+
# Return a ToolResult instance instead of a dictionary
|
106 |
+
return screenshot, path
|
107 |
+
|
108 |
+
raise ToolError(f"Failed to take screenshot: {path} does not exist.")
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
def _get_screen_size(selected_screen: int = 0):
|
114 |
+
if platform.system() == "Windows":
|
115 |
+
# Use screeninfo to get primary monitor on Windows
|
116 |
+
screens = get_monitors()
|
117 |
+
|
118 |
+
# Sort screens by x position to arrange from left to right
|
119 |
+
sorted_screens = sorted(screens, key=lambda s: s.x)
|
120 |
+
if selected_screen is None:
|
121 |
+
primary_monitor = next((m for m in get_monitors() if m.is_primary), None)
|
122 |
+
return primary_monitor.width, primary_monitor.height
|
123 |
+
elif selected_screen < 0 or selected_screen >= len(screens):
|
124 |
+
raise IndexError("Invalid screen index.")
|
125 |
+
else:
|
126 |
+
screen = sorted_screens[selected_screen]
|
127 |
+
return screen.width, screen.height
|
128 |
+
elif platform.system() == "Darwin":
|
129 |
+
# macOS part using Quartz to get screen information
|
130 |
+
max_displays = 32 # Maximum number of displays to handle
|
131 |
+
active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
|
132 |
+
|
133 |
+
# Get the display bounds (resolution) for each active display
|
134 |
+
screens = []
|
135 |
+
for display_id in active_displays:
|
136 |
+
bounds = Quartz.CGDisplayBounds(display_id)
|
137 |
+
screens.append({
|
138 |
+
'id': display_id,
|
139 |
+
'x': int(bounds.origin.x),
|
140 |
+
'y': int(bounds.origin.y),
|
141 |
+
'width': int(bounds.size.width),
|
142 |
+
'height': int(bounds.size.height),
|
143 |
+
'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
|
144 |
+
})
|
145 |
+
|
146 |
+
# Sort screens by x position to arrange from left to right
|
147 |
+
sorted_screens = sorted(screens, key=lambda s: s['x'])
|
148 |
+
|
149 |
+
if selected_screen is None:
|
150 |
+
# Find the primary monitor
|
151 |
+
primary_monitor = next((screen for screen in screens if screen['is_primary']), None)
|
152 |
+
if primary_monitor:
|
153 |
+
return primary_monitor['width'], primary_monitor['height']
|
154 |
+
else:
|
155 |
+
raise RuntimeError("No primary monitor found.")
|
156 |
+
elif selected_screen < 0 or selected_screen >= len(screens):
|
157 |
+
raise IndexError("Invalid screen index.")
|
158 |
+
else:
|
159 |
+
# Return the resolution of the selected screen
|
160 |
+
screen = sorted_screens[selected_screen]
|
161 |
+
return screen['width'], screen['height']
|
162 |
+
|
163 |
+
else: # Linux or other OS
|
164 |
+
cmd = "xrandr | grep ' primary' | awk '{print $4}'"
|
165 |
+
try:
|
166 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
167 |
+
resolution = output.strip().split()[0]
|
168 |
+
width, height = map(int, resolution.split('x'))
|
169 |
+
return width, height
|
170 |
+
except subprocess.CalledProcessError:
|
171 |
+
raise RuntimeError("Failed to get screen resolution on Linux.")
|
dev-requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ruff==0.6.7
|
2 |
+
pre-commit==3.8.0
|
3 |
+
pytest==8.3.3
|
4 |
+
pytest-asyncio==0.23.6
|
5 |
+
pyautogui==0.9.54
|
6 |
+
streamlit>=1.38.0
|
7 |
+
anthropic[bedrock,vertex]>=0.37.1
|
8 |
+
jsonschema==4.22.0
|
9 |
+
boto3>=1.28.57
|
10 |
+
google-auth<3,>=2
|
11 |
+
gradio>=5.6.0
|
12 |
+
screeninfo
|
13 |
+
uiautomation
|
14 |
+
|
15 |
+
# make sure to install the correct version of torch (cuda, mps, cpu, etc.)
|
16 |
+
torch
|
17 |
+
torchvision
|
18 |
+
|
19 |
+
transformers
|
20 |
+
qwen-vl-utils
|
21 |
+
accelerate
|
22 |
+
dashscope
|
23 |
+
huggingface_hub
|
docs/README_cn.md
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h2 align="center">
|
2 |
+
<a href="https://computer-use-ootb.github.io">
|
3 |
+
<img src="../assets/ootb_logo.png" alt="Logo" style="display: block; margin: 0 auto; filter: invert(1) brightness(2);">
|
4 |
+
</a>
|
5 |
+
</h2>
|
6 |
+
|
7 |
+
|
8 |
+
<h5 align="center"> 如果你喜欢我们的项目,请在GitHub上为我们加星⭐以获取最新更新。</h5>
|
9 |
+
|
10 |
+
<h5 align=center>
|
11 |
+
|
12 |
+
[](https://arxiv.org/abs/2411.10323)
|
13 |
+
[](https://computer-use-ootb.github.io)
|
14 |
+
[](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fshowlab%2Fcomputer_use_ootb&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
|
15 |
+
|
16 |
+
|
17 |
+
</h5>
|
18 |
+
|
19 |
+
## <img src="../assets/ootb_icon.png" alt="Star" style="height:25px; vertical-align:middle; filter: invert(1) brightness(2);"> 概览
|
20 |
+
**Computer Use <span style="color:rgb(106, 158, 210)">O</span><span style="color:rgb(111, 163, 82)">O</span><span style="color:rgb(209, 100, 94)">T</span><span style="color:rgb(238, 171, 106)">B</span>**<img src="../assets/ootb_icon.png" alt="Star" style="height:20px; vertical-align:middle; filter: invert(1) brightness(2);"> 是一个桌面GUI Agent的开箱即用(OOTB)解决方案,包括API支持的 (**Claude 3.5 Computer Use**) 和本地运行的模型 (**<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span>UI**)。
|
21 |
+
|
22 |
+
**无需Docker**,支持 **Windows** 和 **macOS**。本项目提供了一个基于Gradio的用户友好界面。🎨
|
23 |
+
|
24 |
+
想了解更多信息,请访问我们关于Claude 3.5 Computer Use的研究 [[项目页面]](https://computer-use-ootb.github.io)。🌐
|
25 |
+
|
26 |
+
## 更新
|
27 |
+
- **<span style="color:rgb(231, 183, 98)">重大更新!</span> [2024/12/04]** **本地运行🔥** 已上线!欢迎使用 [**<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span>UI**](https://github.com/showlab/ShowUI),一个开源的2B视觉-语言-动作(VLA)模型作为GUI Agent。现在可兼容 `"gpt-4o + ShowUI" (~便宜200倍)`* 及 `"Qwen2-VL + ShowUI" (~便宜30倍)`*,只需几美分💰! <span style="color: grey; font-size: small;">*与Claude Computer Use相比</span>。
|
28 |
+
- **[2024/11/20]** 我们添加了一些示例来帮助你上手Claude 3.5 Computer Use。
|
29 |
+
- **[2024/11/19]** 不再受Anthropic单显示器限制——现在你可以使用 **多显示器** 🎉!
|
30 |
+
- **[2024/11/18]** 我们发布了Claude 3.5 Computer Use的深度分析: [https://arxiv.org/abs/2411.10323](https://arxiv.org/abs/2411.10323)。
|
31 |
+
- **[2024/11/11]** 不再受Anthropic低分辨率显示限制——你可以使用 *任意分辨率* 同时保持 **截图token成本较低** 🎉!
|
32 |
+
- **[2024/11/11]** 现在 **Windows** 和 **macOS** 两个平台均已支持 🎉!
|
33 |
+
- **[2024/10/25]** 现在你可以通过手机设备 📱 **远程控制** 你的电脑 💻——**无需在手机上安装APP**!试试吧,玩得开心 🎉。
|
34 |
+
|
35 |
+
## 演示视频
|
36 |
+
|
37 |
+
https://github.com/user-attachments/assets/f50b7611-2350-4712-af9e-3d31e30020ee
|
38 |
+
|
39 |
+
<div style="display: flex; justify-content: space-around;">
|
40 |
+
<a href="https://youtu.be/Ychd-t24HZw" target="_blank" style="margin-right: 10px;">
|
41 |
+
<img src="https://img.youtube.com/vi/Ychd-t24HZw/maxresdefault.jpg" alt="Watch the video" width="48%">
|
42 |
+
</a>
|
43 |
+
<a href="https://youtu.be/cvgPBazxLFM" target="_blank">
|
44 |
+
<img src="https://img.youtube.com/vi/cvgPBazxLFM/maxresdefault.jpg" alt="Watch the video" width="48%">
|
45 |
+
</a>
|
46 |
+
</div>
|
47 |
+
|
48 |
+
|
49 |
+
## 🚀 开始使用
|
50 |
+
|
51 |
+
### 0. 前置条件
|
52 |
+
- 请通过此[链接](https://www.anaconda.com/download?utm_source=anacondadocs&utm_medium=documentation&utm_campaign=download&utm_content=topnavalldocs)安装 Miniconda。(**Python版本:≥3.11**)
|
53 |
+
- 硬件要求(可选,针对ShowUI本地运行):
|
54 |
+
- **Windows (支持CUDA)**: 有CUDA支持的NVIDIA GPU,GPU显存≥6GB
|
55 |
+
- **macOS (Apple Silicon)**: M1芯片(或更新),统一RAM≥16GB
|
56 |
+
|
57 |
+
|
58 |
+
### 1. 克隆仓库 📂
|
59 |
+
打开Conda终端。(安装Miniconda后,将在开始菜单出现)
|
60 |
+
在 **Conda终端** 中运行以下命令:
|
61 |
+
```bash
|
62 |
+
git clone https://github.com/showlab/computer_use_ootb.git
|
63 |
+
cd computer_use_ootb
|
64 |
+
```
|
65 |
+
|
66 |
+
### 2.1 安装依赖 🔧
|
67 |
+
```
|
68 |
+
pip install -r dev-requirements.txt
|
69 |
+
```
|
70 |
+
|
71 |
+
### 2.2 (可选)为 **<span style="color:rgb(106, 158, 210)">S</span><span style="color:rgb(111, 163, 82)">h</span><span style="color:rgb(209, 100, 94)">o</span><span style="color:rgb(238, 171, 106)">w</span>UI** 本地运行做准备
|
72 |
+
|
73 |
+
1. 使用以下命令下载 ShowUI-2B 模型的所有文件。确保 ShowUI-2B 文件夹位于 computer_use_ootb 文件夹下。
|
74 |
+
|
75 |
+
|
76 |
+
```
|
77 |
+
python install_showui.py
|
78 |
+
```
|
79 |
+
|
80 |
+
|
81 |
+
2. 在您的机器上安装正确的 GPU 版 PyTorch(CUDA、MPS 等)。请参考 [安装指南与验证](https://pytorch.org/get-started/locally/)。
|
82 |
+
|
83 |
+
3. 获取 [GPT-4o](https://platform.openai.com/docs/quickstart) 或 [Qwen-VL](https://help.aliyun.com/zh/dashscope/developer-reference/acquisition-and-configuration-of-api-key) 的 API Key。对于中国大陆用户,可享受 Qwen API 免费试用 100 万token:[点击查看](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api)。
|
84 |
+
|
85 |
+
### 3. 启动界面 ▶️
|
86 |
+
|
87 |
+
**启动 OOTB 界面:**
|
88 |
+
```
|
89 |
+
python app.py
|
90 |
+
```
|
91 |
+
|
92 |
+
若成功启动界面,您将在终端中看到两个 URL:
|
93 |
+
```
|
94 |
+
* Running on local URL: http://127.0.0.1:7860
|
95 |
+
* Running on public URL: https://xxxxxxxxxxxxxxxx.gradio.live (请勿与他人分享此链接,否则他们可控制您的电脑。)
|
96 |
+
```
|
97 |
+
|
98 |
+
|
99 |
+
> <u>为方便起见</u>,我们推荐在启动界面前运行以下命令,将 API 密钥设置为环境变量。这样您无需在每次运行时手动输入。
|
100 |
+
在 Windows Powershell 中(如在 cmd 中则使用 set 命令):
|
101 |
+
>
|
102 |
+
```
|
103 |
+
$env:ANTHROPIC_API_KEY="sk-xxxxx" (替换为您的密钥)
|
104 |
+
$env:QWEN_API_KEY="sk-xxxxx"
|
105 |
+
$env:OPENAI_API_KEY="sk-xxxxx"
|
106 |
+
```
|
107 |
+
|
108 |
+
> 在 macOS/Linux 中,将上述命令中的 $env:ANTHROPIC_API_KEY 替换为 export ANTHROPIC_API_KEY 即可。
|
109 |
+
|
110 |
+
|
111 |
+
### 4. 使用任意可访问网络的设备控制您的电脑
|
112 |
+
- **待控制的电脑**:安装了上述软件的那台电脑。
|
113 |
+
- **发送指令的设备**:打开网址的任意设备。
|
114 |
+
|
115 |
+
在本机浏览器中打开 http://localhost:7860/(若在本机控制)或在您的手机浏览器中打开 https://xxxxxxxxxxxxxxxxx.gradio.live(若远程控制)。
|
116 |
+
|
117 |
+
输入 Anthropic API 密钥(可通过[此页面](https://console.anthropic.com/settings/keys)获取),然后给出指令让 AI 执行任务。
|
118 |
+
|
119 |
+
<div style="display: flex; align-items: center; gap: 10px;">
|
120 |
+
<figure style="text-align: center;">
|
121 |
+
<img src="./assets/gradio_interface.png" alt="Desktop Interface" style="width: auto; object-fit: contain;">
|
122 |
+
</figure>
|
123 |
+
</div>
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
## 🖥️ 支持的系统
|
128 |
+
- **Windows** (Claude ✅, ShowUI ✅)
|
129 |
+
- **macOS** (Claude ✅, ShowUI ✅)
|
130 |
+
|
131 |
+
## ⚠️ 风险
|
132 |
+
- **模型可能执行危险操作**:模型仍有局限性,可能生成非预期或潜在有害的输出。建议持续监督 AI 的操作。
|
133 |
+
- **成本控制**:每个任务可能花费几美元(Claude 3.5 Computer Use)。💸
|
134 |
+
|
135 |
+
## 📅 路线图
|
136 |
+
- [ ] **探索可用功能**
|
137 |
+
- [ ] Claude API 在解决任务时似乎不稳定。我们正在调查原因:分辨率、操作类型、操作系统平台或规划机制等。欢迎提出想法或评论。
|
138 |
+
- [ ] **界面设计**
|
139 |
+
- [x] **支持 Gradio** ✨
|
140 |
+
- [ ] **更简单的安装流程**
|
141 |
+
- [ ] **更多特性**... 🚀
|
142 |
+
- [ ] **平台**
|
143 |
+
- [x] **Windows**
|
144 |
+
- [x] **移动端**(发出指令)
|
145 |
+
- [x] **macOS**
|
146 |
+
- [ ] **移动端**(被控制)
|
147 |
+
- [ ] **支持更多多模态大模型(MLLMs)**
|
148 |
+
- [x] **Claude 3.5 Sonnet** 🎵
|
149 |
+
- [x] **GPT-4o**
|
150 |
+
- [x] **Qwen2-VL**
|
151 |
+
- [ ] ...
|
152 |
+
- [ ] **改进提示策略**
|
153 |
+
- [ ] 优化提示以降低成本。💡
|
154 |
+
- [ ] **提升推理速度**
|
155 |
+
- [ ] 支持 int8 量化。
|
156 |
+
|
157 |
+
## 加入讨论
|
158 |
+
欢迎加入讨论,与我们一同不断改进 Computer Use - OOTB 的用户体验。可通过 [**Discord 频道**](https://discord.gg/HnHng5de) 或下方微信二维码联系我们!
|
159 |
+
|
160 |
+
<div style="display: flex; flex-direction: row; justify-content: space-around;">
|
161 |
+
|
162 |
+
<img src="../assets/wechat_2.jpg" alt="gradio_interface" width="30%">
|
163 |
+
<img src="../assets/wechat.jpg" alt="gradio_interface" width="30%">
|
164 |
+
|
165 |
+
</div>
|
166 |
+
|
167 |
+
<div style="height: 30px;"></div>
|
168 |
+
|
169 |
+
<hr>
|
170 |
+
<a href="https://computer-use-ootb.github.io">
|
171 |
+
<img src="../assets/ootb_logo.png" alt="Logo" width="30%" style="display: block; margin: 0 auto; filter: invert(1) brightness(2);">
|
172 |
+
</a>
|
install_tools/install_showui-awq-4bit.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
3 |
+
|
4 |
+
# Specify the model repository and destination folder
|
5 |
+
model_repo = "yyyang/showui-2b-awq"
|
6 |
+
destination_folder = "./showui-2b-awq"
|
7 |
+
|
8 |
+
# Ensure the destination folder exists
|
9 |
+
os.makedirs(destination_folder, exist_ok=True)
|
10 |
+
|
11 |
+
# List all files in the repository
|
12 |
+
files = list_repo_files(repo_id=model_repo)
|
13 |
+
|
14 |
+
# Download each file to the destination folder
|
15 |
+
for file in files:
|
16 |
+
file_path = hf_hub_download(repo_id=model_repo, filename=file, local_dir=destination_folder)
|
17 |
+
print(f"Downloaded {file} to {file_path}")
|
install_tools/install_showui.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
3 |
+
|
4 |
+
# Specify the model repository and destination folder
|
5 |
+
model_repo = "showlab/ShowUI-2B"
|
6 |
+
destination_folder = "./showui-2b"
|
7 |
+
|
8 |
+
# Ensure the destination folder exists
|
9 |
+
os.makedirs(destination_folder, exist_ok=True)
|
10 |
+
|
11 |
+
# List all files in the repository
|
12 |
+
files = list_repo_files(repo_id=model_repo)
|
13 |
+
|
14 |
+
# Download each file to the destination folder
|
15 |
+
for file in files:
|
16 |
+
file_path = hf_hub_download(repo_id=model_repo, filename=file, local_dir=destination_folder)
|
17 |
+
print(f"Downloaded {file} to {file_path}")
|
install_tools/install_uitars-2b-sft.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
3 |
+
|
4 |
+
# Specify the model repository and destination folder
|
5 |
+
model_repo = "bytedance-research/UI-TARS-2B-SFT"
|
6 |
+
destination_folder = "./ui-tars-2b-sft"
|
7 |
+
|
8 |
+
# Ensure the destination folder exists
|
9 |
+
os.makedirs(destination_folder, exist_ok=True)
|
10 |
+
|
11 |
+
# List all files in the repository
|
12 |
+
files = list_repo_files(repo_id=model_repo)
|
13 |
+
|
14 |
+
# Download each file to the destination folder
|
15 |
+
for file in files:
|
16 |
+
file_path = hf_hub_download(repo_id=model_repo, filename=file, local_dir=destination_folder)
|
17 |
+
print(f"Downloaded {file} to {file_path}")
|
install_tools/test_ui-tars_server.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
from computer_use_demo.gui_agent.llm_utils.oai import encode_image
|
3 |
+
|
4 |
+
_NAV_SYSTEM_GROUNDING = """
|
5 |
+
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
6 |
+
|
7 |
+
## Output Format
|
8 |
+
```Action: ...```
|
9 |
+
|
10 |
+
## Action Space
|
11 |
+
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
12 |
+
hotkey(key='')
|
13 |
+
type(content='') #If you want to submit your input, use \"\" at the end of `content`.
|
14 |
+
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
15 |
+
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
16 |
+
finished()
|
17 |
+
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
18 |
+
|
19 |
+
## Note
|
20 |
+
- Do not generate any other text.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def get_prompt_grounding(task):
|
24 |
+
return f"""{task}"""
|
25 |
+
|
26 |
+
task = """
|
27 |
+
```json
|
28 |
+
{{ "Observation": "I am on the google homepage of the Chrome browser.",
|
29 |
+
"Thinking": "The user wants to buy a lap-top on Amazon.com, so I need to click on the address (search) bar of Chrome for entering the 'Amazon.com'.",
|
30 |
+
"Next Action": ["I need to click DSML"],
|
31 |
+
"Expectation": "The search button is activated after being clicked, ready to input."
|
32 |
+
}}```
|
33 |
+
"""
|
34 |
+
|
35 |
+
task = """
|
36 |
+
```json
|
37 |
+
{{
|
38 |
+
"Observation": "I am on the google homepage of the Chrome browser.",
|
39 |
+
"Thinking": "The user wants to click DSML",
|
40 |
+
"Next Action": ["I need to click DSML"],
|
41 |
+
}}```
|
42 |
+
"""
|
43 |
+
|
44 |
+
task = """
|
45 |
+
```json
|
46 |
+
{{
|
47 |
+
"Observation": "I am on the google homepage of the Chrome browser.",
|
48 |
+
"Thinking": "The user wants to click Youtube",
|
49 |
+
"Next Action": ["I need to click Youtube"],
|
50 |
+
}}```
|
51 |
+
"""
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
|
55 |
+
ui_tars_url = "https://your_api_to_uitars.com/v1"
|
56 |
+
ui_tars_client = OpenAI(base_url=ui_tars_url, api_key="")
|
57 |
+
grounding_system_prompt = _NAV_SYSTEM_GROUNDING.format()
|
58 |
+
screenshot_base64 = encode_image("./chrome.png")
|
59 |
+
prompted_message = get_prompt_grounding(task)
|
60 |
+
|
61 |
+
print(f"grounding_system_prompt, {grounding_system_prompt}, \
|
62 |
+
prompted_message: {prompted_message}")
|
63 |
+
|
64 |
+
response = ui_tars_client.chat.completions.create(
|
65 |
+
model="ui-tars",
|
66 |
+
messages=[
|
67 |
+
{"role": "user", "content": grounding_system_prompt},
|
68 |
+
{"role": "user", "content": [
|
69 |
+
{"type": "text", "text": prompted_message},
|
70 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{screenshot_base64}"}}
|
71 |
+
]
|
72 |
+
},
|
73 |
+
],
|
74 |
+
max_tokens=128,
|
75 |
+
temperature=0
|
76 |
+
)
|
77 |
+
|
78 |
+
ui_tars_action = response.choices[0].message.content
|
79 |
+
|
80 |
+
print(response.choices[0].message.content)
|
81 |
+
|
82 |
+
|