"""VIP.""" import json import re import cv2 from tqdm import trange import vip def make_prompt(description, top_n=3): return f""" INSTRUCTIONS: You are tasked to locate an object, region, or point in space in the given annotated image according to a description. The image is annoated with numbered circles. Choose the top {top_n} circles that have the most overlap with and/or is closest to what the description is describing in the image. You are a five-time world champion in this game. Give a one sentence analysis of why you chose those points. Provide your answer at the end in a valid JSON of this format: {{"points": []}} DESCRIPTION: {description} IMAGE: """.strip() def extract_json(response, key): json_part = re.search(r"\{.*\}", response, re.DOTALL) parsed_json = {} if json_part: json_data = json_part.group() # Parse the JSON data parsed_json = json.loads(json_data) else: print("No JSON data found ******\n", response) return parsed_json[key] def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n): """Perform one selection pass given samples.""" image_circles_np = prompter.add_arrow_overlay_plt( image=im, samples=samples, arm_xy=arm_coord, log_image=False ) _, encoded_image_circles = cv2.imencode(".png", image_circles_np) prompt_seq = [make_prompt(desc, top_n=top_n), encoded_image_circles] response = vlm.query(prompt_seq) arrow_ids = extract_json(response, "points") return arrow_ids, image_circles_np def vip_runner( vlm, im, desc, style, action_spec, n_samples_init=25, n_samples_opt=10, n_iters=3, recursion_level=0, ): """VIP.""" prompter = vip.VisualIterativePrompter( style, action_spec, vip.SupportedEmbodiments.META_NAVIGATION ) output_ims = [] arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2)) if recursion_level == 0: center_mean = action_spec["loc"] center_std = action_spec["scale"] selected_samples = [] for itr in trange(n_iters): if itr == 0: style["num_samples"] = n_samples_init else: style["num_samples"] = n_samples_opt samples = prompter.sample_actions(im, arm_coord, center_mean, center_std) arrow_ids, image_circles_np = vip_perform_selection( prompter, vlm, im, desc, arm_coord, samples, top_n=3 ) # plot sampled circles as red selected_samples = [] for selected_id in arrow_ids: sample = samples[selected_id] sample.coord.color = (255, 0, 0) selected_samples.append(sample) image_circles_marked_np = prompter.add_arrow_overlay_plt( image_circles_np, selected_samples, arm_coord ) output_ims.append(image_circles_marked_np) # if at last iteration, pick one answer out of the selected ones if itr == n_iters - 1: arrow_ids, _ = vip_perform_selection( prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1 ) selected_samples = [] for selected_id in arrow_ids: sample = samples[selected_id] sample.coord.color = (255, 0, 0) selected_samples.append(sample) image_circles_marked_np = prompter.add_arrow_overlay_plt( im, selected_samples, arm_coord ) output_ims.append(image_circles_marked_np) center_mean, center_std = prompter.fit(arrow_ids, samples) if output_ims: return ( output_ims, prompter.action_to_coord(center_mean, im, arm_coord).xy, selected_samples, ) else: new_samples = [] for i in range(3): out_ims, _, cur_samples = vip_runner( vlm=vlm, im=im, desc=desc, style=style, action_spec=action_spec, n_samples_init=n_samples_init, n_samples_opt=n_samples_opt, n_iters=n_iters, recursion_level=recursion_level - 1, ) output_ims += out_ims new_samples += cur_samples # adjust sample label to avoid duplications for sample_id in range(len(new_samples)): new_samples[sample_id].label = str(sample_id) arrow_ids, _ = vip_perform_selection( prompter, vlm, im, desc, arm_coord, new_samples, top_n=1 ) selected_samples = [] for selected_id in arrow_ids: sample = new_samples[selected_id] sample.coord.color = (255, 0, 0) selected_samples.append(sample) image_circles_marked_np = prompter.add_arrow_overlay_plt( im, selected_samples, arm_coord ) output_ims.append(image_circles_marked_np) center_mean, _ = prompter.fit(arrow_ids, new_samples) if output_ims: return ( output_ims, prompter.action_to_coord(center_mean, im, arm_coord).xy, selected_samples, ) return [], "Unable to understand query"