from typing import Dict, Any

import base64
import os
import tempfile
from io import BytesIO

import transformers
from PIL import Image

COORDINATE_PROMPT = 'In this UI screenshot, what is the position of the element corresponding to the command "{command}" (with point)?'

PARTITION_PROMPT = 'In this UI screenshot, what is the partition of the element corresponding to the command "{command}" (with quadrant number)?'

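
# The handler localizes a UI element with a coarse-to-fine strategy: it repeatedly
# asks the model which quadrant of the current (cropped) screenshot contains the
# target element, zooms into that quadrant, and finally asks for a point inside the
# last crop. The predicted relative coordinates are then mapped back to the full
# screenshot.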
class EndpointHandler:
    def __init__(self, path=""):
        # Load the (Qwen-VL-style) chat model in fp16 on the GPU.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            device_map="cuda",
            trust_remote_code=True,
            fp16=True,
        ).eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            path,
            cache_dir=None,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )
        # The tokenizer has no pad token by default; reuse its end-of-document token.
        self.tokenizer.pad_token_id = self.tokenizer.eod_id

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            image (:obj:`str`): base64-encoded screenshot (PNG)
            task (:obj:`str`): natural-language command describing the target element
            k (:obj:`str`): number of quadrant-refinement steps
            context (:obj:`str`): whether to keep all previous crops in the query
        Return:
            A :obj:`dict` with the predicted relative coordinates; it will be
            serialized and returned (see the usage sketch at the end of this file).
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            # Decode the base64 screenshot and write it to disk, since the chat
            # interface consumes image paths.
            image = os.path.join(temp_dir, "image.png")
            img = Image.open(BytesIO(base64.b64decode(data["inputs"]["image"])))
            img.save(image)

            command = data["inputs"]["task"]
            K = int(data["inputs"]["k"])
            # "context" may arrive as a bool or as a string such as "true"/"false";
            # bool("false") would be True, so normalize it explicitly.
            context = data["inputs"]["context"]
            if isinstance(context, bool):
                keep_context = context
            else:
                keep_context = str(context).lower() in ("1", "true", "yes")

            print(image)
            print(command)
            print(K)
            print(keep_context)

            images = [image]
            partitions = []

            # Coarse-to-fine localization: at each step ask which quadrant of the
            # current crop contains the target element, then zoom into it.
            for k in range(K):
                query = self.tokenizer.from_list_format(
                    ([{'image': context_image} for context_image in images] if keep_context else [{'image': image}])
                    + [{'text': PARTITION_PROMPT.format(command=command)}]
                )
                response, _ = self.model.chat(self.tokenizer, query=query, history=None)

                # The model answers with the quadrant number as the last token.
                partition = int(response.split(" ")[-1])
                partitions.append(partition)

                # Crop the selected quadrant. Quadrants follow the mathematical
                # convention: 1 = top-right, 2 = top-left, 3 = bottom-left,
                # 4 = bottom-right.
                with Image.open(image) as img:
                    width, height = img.size
                    if partition == 1:
                        img = img.crop((width // 2, 0, width, height // 2))
                    elif partition == 2:
                        img = img.crop((0, 0, width // 2, height // 2))
                    elif partition == 3:
                        img = img.crop((0, height // 2, width // 2, height))
                    elif partition == 4:
                        img = img.crop((width // 2, height // 2, width, height))

                    new_path = os.path.join(temp_dir, f"partition{k}.png")
                    img.save(new_path)
                    image = new_path
                    images.append(image)

            # Ask for the point inside the final crop (optionally with all previous
            # crops as context).
            query = self.tokenizer.from_list_format(
                ([{'image': context_image} for context_image in images] if keep_context else [{'image': image}])
                + [{'text': COORDINATE_PROMPT.format(command=command)}]
            )
            response, _ = self.model.chat(self.tokenizer, query=query, history=None)
            print("Coordinate Response:", response)

            # The model answers with a relative point of the form "(x,y)".
            x = float(response.split(",")[0].split("(")[1])
            y = float(response.split(",")[1].split(")")[0])

            # Map the point from the innermost crop back to the full screenshot by
            # undoing the quadrant crops in reverse order: each crop halves both
            # axes, and quadrants on the right/bottom add an offset of 0.5.
            for partition in partitions[::-1]:
                if partition == 1:
                    x = x / 2 + 0.5
                    y = y / 2
                elif partition == 2:
                    x = x / 2
                    y = y / 2
                elif partition == 3:
                    x = x / 2
                    y = y / 2 + 0.5
                elif partition == 4:
                    x = x / 2 + 0.5
                    y = y / 2 + 0.5
            print("rescaled point:", x, y)

            return {'x': x, 'y': y}
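

# Minimal local usage sketch, not part of the handler itself. The checkpoint path
# "./checkpoint" and the file "screenshot.png" are placeholder assumptions; on
# Inference Endpoints the handler is instantiated and called by the serving stack.
if __name__ == "__main__":
    with open("screenshot.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path="./checkpoint")
    payload = {
        "inputs": {
            "image": encoded,                   # base64-encoded PNG screenshot
            "task": "click the search button",  # hypothetical command
            "k": "2",                           # two quadrant-refinement steps
            "context": "true",                  # keep previous crops in the query
        }
    }
    # Prints relative coordinates in the full screenshot, e.g. {'x': ..., 'y': ...}.
    print(handler(payload))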