# k_2_model/handler.py
from typing import Dict, Any
import transformers
import os
import tempfile
from PIL import Image
from io import BytesIO
import base64

COORDINATE_PROMPT = 'In this UI screenshot, what is the position of the element corresponding to the command "{command}" (with point)?'
PARTITION_PROMPT = 'In this UI screenshot, what is the partition of the element corresponding to the command "{command}" (with quadrant number)?'
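

# The handler grounds a natural-language command on a UI screenshot by iterative
# zooming: for K rounds it asks the model which quadrant contains the target
# element, crops the screenshot down to that quadrant, then asks for a point in
# the final crop and maps the normalized coordinates back onto the full image.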
class EndpointHandler:
    def __init__(self, path=""):
        # Load the fine-tuned checkpoint in half precision on the GPU;
        # trust_remote_code=True allows the checkpoint's custom modeling code to load.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            device_map="cuda",
            trust_remote_code=True,
            fp16=True,
        ).eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            path,
            cache_dir=None,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )
        self.tokenizer.pad_token_id = self.tokenizer.eod_id
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args (under data["inputs"]):
            image (:obj:`str`): base64-encoded screenshot
            task (:obj:`str`): natural-language command to locate on the screen
            k (:obj:`str`): number of quadrant zoom-in rounds
            context (:obj:`str`): whether to include all previous crops in each prompt
        Return:
            A :obj:`dict` with normalized "x" and "y" coordinates; it will be serialized and returned
        """
        # Decode the screenshot into a temporary directory so the chat API can
        # reference it by file path.
        with tempfile.TemporaryDirectory() as temp_dir:
            image = os.path.join(temp_dir, "image.png")
            img = Image.open(BytesIO(base64.b64decode(data["inputs"]["image"])))
            img.save(image)

            command = data["inputs"]["task"]
            K = int(data["inputs"]["k"])
            # Note: bool() follows Python truthiness, so any non-empty string counts as True.
            keep_context = bool(data["inputs"]["context"])
            print(image)
            print(command)
            print(K)
            print(keep_context)

            images = [image]
            partitions = []
            for k in range(K):
                # Ask the model which quadrant (1-4) contains the target element.
                query = self.tokenizer.from_list_format(
                    ([{'image': context_image} for context_image in images] if keep_context else [{'image': image}])
                    + [{'text': PARTITION_PROMPT.format(command=command)}])
                response, _ = self.model.chat(self.tokenizer, query=query, history=None)
                partition = int(response.split(" ")[-1])
                partitions.append(partition)

                # Crop to the predicted quadrant:
                # 1 = top-right, 2 = top-left, 3 = bottom-left, 4 = bottom-right.
                with Image.open(image) as img:
                    width, height = img.size
                    if partition == 1:
                        img = img.crop((width // 2, 0, width, height // 2))
                    elif partition == 2:
                        img = img.crop((0, 0, width // 2, height // 2))
                    elif partition == 3:
                        img = img.crop((0, height // 2, width // 2, height))
                    elif partition == 4:
                        img = img.crop((width // 2, height // 2, width, height))
                    new_path = os.path.join(temp_dir, f"partition{k}.png")
                    img.save(new_path)
                image = new_path
                images.append(image)
            # Ask for the exact point inside the final (most zoomed-in) crop.
            query = self.tokenizer.from_list_format(
                ([{'image': context_image} for context_image in images] if keep_context else [{'image': image}])
                + [{'text': COORDINATE_PROMPT.format(command=command)}])
            response, _ = self.model.chat(self.tokenizer, query=query, history=None)
            print("Coordinate Response:", response)
            x = float(response.split(",")[0].split("(")[1])
            y = float(response.split(",")[1].split(")")[0])

            # Undo the crops in reverse order, mapping the normalized point from
            # each quadrant back into its parent image (e.g. quadrant 1, the
            # top-right, maps x into [0.5, 1] and y into [0, 0.5]).
            for partition in partitions[::-1]:
                if partition == 1:
                    x = x / 2 + 0.5
                    y = y / 2
                elif partition == 2:
                    x = x / 2
                    y = y / 2
                elif partition == 3:
                    x = x / 2
                    y = y / 2 + 0.5
                elif partition == 4:
                    x = x / 2 + 0.5
                    y = y / 2 + 0.5

            print("rescaled point:", x, y)
            return {'x': x, 'y': y}
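

# A minimal local-invocation sketch. The checkpoint directory ("./k_2_model") and
# screenshot file ("screenshot.png") below are illustrative placeholders, not
# assets shipped with this repository; adjust them to your own paths.
if __name__ == "__main__":
    handler = EndpointHandler(path="./k_2_model")  # assumed local checkpoint path
    with open("screenshot.png", "rb") as f:        # assumed UI screenshot on disk
        encoded = base64.b64encode(f.read()).decode("utf-8")
    payload = {
        "inputs": {
            "image": encoded,                    # base64-encoded screenshot
            "task": "click the search button",   # example command to ground
            "k": "2",                            # two quadrant zoom-in rounds
            "context": "",                       # empty string -> keep_context is False
        }
    }
    # Prints a dict with normalized "x" and "y" coordinates on the original screenshot.
    print(handler(payload))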