import logging

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import ImageDraw

logger = logging.getLogger(__name__)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Task token prepended to the user's description and reused to select the
# processor's post-processing and to index its result.
# NOTE(review): the pasted source had empty strings in all three places
# (`task_prompt = ""`, `task=""`, `parsed_answer[""]`), which looks like an
# angle-bracket token stripped in transit; with "" no box is ever parsed.
# "<OPEN_VOCABULARY_DETECTION>" matches the PTA-1 model card and the
# "bboxes"/"polygons" keys read below — confirm against the card.
TASK_PROMPT = "<OPEN_VOCABULARY_DETECTION>"

models = {
    "AskUI/PTA-1": AutoModelForCausalLM.from_pretrained("AskUI/PTA-1", trust_remote_code=True),
}

processors = {
    "AskUI/PTA-1": AutoProcessor.from_pretrained("AskUI/PTA-1", trust_remote_code=True),
}


def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=3):
    """Draw rectangle outlines onto *image* in place and return it.

    Args:
        image: PIL image to annotate; it is modified in place.
        bounding_boxes: Iterable of ``[xmin, ymin, xmax, ymax]`` boxes in pixel
            coordinates. ``None`` entries (no detection) are skipped.
        outline_color: Outline colour passed to ``ImageDraw.rectangle``.
        line_width: Outline width in pixels.

    Returns:
        The same ``image`` object.
    """
    draw = ImageDraw.Draw(image)
    for box in bounding_boxes:
        if box is None:
            # florence_output_to_box returns None when nothing usable was found.
            continue
        xmin, ymin, xmax, ymax = box
        draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
    return image


def florence_output_to_box(output):
    """Reduce a parsed detection result to a single ``[xmin, ymin, xmax, ymax]`` box.

    The first polygon is preferred when polygons are present; otherwise the
    first bbox is used. This is best-effort: parsing problems are logged and
    reported as ``None`` rather than raised.

    Args:
        output: The per-task value from ``processor.post_process_generation``;
            presumably a dict with ``"polygons"`` and/or ``"bboxes"`` lists —
            TODO confirm against the model's remote processor code.

    Returns:
        A list of four ints, or ``None`` if no usable box was found.
    """
    if not isinstance(output, dict):
        # e.g. a plain-text answer, or the task key missing from the result.
        logger.warning("Unexpected model output (expected a dict): %r", output)
        return None
    try:
        polygons = output.get("polygons")
        if polygons:
            # Assumes polygons[0][0] is a flat [x0, y0, x1, y1, ...] list.
            # Take its axis-aligned bounds rather than trusting vertex order;
            # this also guarantees xmin <= xmax for ImageDraw.rectangle.
            coords = [int(el) for el in polygons[0][0]]
            xs, ys = coords[0::2], coords[1::2]
            if xs and ys:
                return [min(xs), min(ys), max(xs), max(ys)]
        bboxes = output.get("bboxes")
        if bboxes:
            return [int(el) for el in bboxes[0]]
    except (TypeError, ValueError, IndexError, KeyError):
        logger.exception("Could not convert model output to a box: %r", output)
    return None


def run_example(image, text_input, model_id="AskUI/PTA-1"):
    """Locate the UI element described by *text_input* in *image*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input (``None`` if
            the user did not upload one).
        text_input: Free-text description of the element to find.
        model_id: Key into ``models`` / ``processors``.

    Returns:
        ``(target_box, annotated_image)``: ``target_box`` is
        ``[xmin, ymin, xmax, ymax]`` or ``None`` when nothing was found;
        ``annotated_image`` is the RGB-converted image with the box drawn on it
        (unannotated when ``target_box`` is ``None``).

    Raises:
        gr.Error: if no image or an empty description was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not text_input or not text_input.strip():
        raise gr.Error("Please describe the element to locate.")

    # nn.Module.to works in place, so the move mostly matters on the first call.
    model = models[model_id].to(device, torch_dtype)
    processor = processors[model_id]

    prompt = TASK_PROMPT + text_input
    # Image.convert returns a converted copy (Pillow docs), so the drawing
    # below does not touch the caller's image object.
    image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=TASK_PROMPT, image_size=(image.width, image.height)
    )
    target_box = florence_output_to_box(parsed_answer.get(TASK_PROMPT))
    return target_box, draw_bounding_boxes(image, [target_box])


# NOTE(review): no component below sets elem_id="output", so this rule
# currently has no effect — confirm whether it should target annotated_image.
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# PTA-1: Controlling Computers with Small Models")
    gr.Markdown("Check out the model [AskUI/PTA-1](https://huggingface.co/AskUI/PTA-1).")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="AskUI/PTA-1")
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            annotated_image = gr.Image(label="Annotated Image")

    gr.Examples(
        examples=[
            ["assets/sample.png", "search box"],
            ["assets/sample.png", "Query Service"],
            ["assets/ipad.png", "App Store icon"],
            ["assets/ipad.png", 'colorful icon with letter "S"'],
            ["assets/phone.jpg", "password field"],
            ["assets/phone.jpg", "back arrow icon"],
            ["assets/windows.jpg", "icon with letter S"],
            ["assets/windows.jpg", "Settings"],
        ],
        # model_id is omitted here; run_example falls back to its default.
        inputs=[input_img, text_input],
        outputs=[model_output_text, annotated_image],
        fn=run_example,
        cache_examples=False,
        label="Try examples",
    )

    submit_btn.click(
        run_example,
        [input_img, text_input, model_selector],
        [model_output_text, annotated_image],
    )

demo.launch(debug=False, ssr_mode=False)