# File size: 4,873 Bytes
# 5c80958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""VIP."""

import json
import re

import cv2
from tqdm import trange
import vip


def make_prompt(description, top_n=3):
  """Build the text prompt that asks the model to pick annotated circles.

  Args:
    description: Text describing the object, region, or point to locate; it is
      interpolated into the "DESCRIPTION:" line.
    top_n: Number of circles the prompt asks the model to choose.

  Returns:
    The prompt as a single newline-joined string, starting with
    "INSTRUCTIONS:" and ending with "IMAGE:".
  """
  # NOTE: the wording below is runtime text sent to the model and is kept
  # verbatim, including the "annoated" spelling and the trailing space after
  # "game.".
  prompt_lines = (
      "INSTRUCTIONS:",
      "You are tasked to locate an object, region, or point in space in the"
      " given annotated image according to a description.",
      "The image is annoated with numbered circles.",
      f"Choose the top {top_n} circles that have the most overlap with and/or"
      " is closest to what the description is describing in the image.",
      "You are a five-time world champion in this game. ",
      "Give a one sentence analysis of why you chose those points.",
      "Provide your answer at the end in a valid JSON of this format:",
      "",
      '{"points": []}',
      "",
      f"DESCRIPTION: {description}",
      "IMAGE:",
  )
  return "\n".join(prompt_lines)


def extract_json(response, key):
  """Return the value stored under `key` in the JSON object inside `response`.

  The span from the first "{" to the last "}" is tried first. If that span is
  not valid JSON (for example the reply contains stray braces in its prose, or
  more than one JSON object), every "{" position is tried in turn and the last
  decodable object that contains `key` is used.

  Args:
    response: Text that is expected to contain a JSON object.
    key: Key to look up in the parsed object.

  Returns:
    The value stored under `key`.

  Raises:
    KeyError: If `response` contains no "{...}" span at all (a diagnostic is
      printed first), or the parsed object does not contain `key`.
    json.JSONDecodeError: If a brace span exists but no JSON object containing
      `key` could be decoded from `response`.
  """
  greedy = re.search(r"\{.*\}", response, re.DOTALL)
  if greedy is None:
    print("No JSON data found ******\n", response)
    raise KeyError(key)

  try:
    parsed_json = json.loads(greedy.group())
  except json.JSONDecodeError:
    # The greedy span swallowed something that is not JSON. Fall back to
    # decoding from each opening brace; later matches win, so the final
    # answer in the text is preferred over earlier ones.
    decoder = json.JSONDecoder()
    parsed_json = None
    for brace in re.finditer(r"\{", response):
      try:
        candidate, _ = decoder.raw_decode(response[brace.start():])
      except json.JSONDecodeError:
        continue
      if isinstance(candidate, dict) and key in candidate:
        parsed_json = candidate
    if parsed_json is None:
      # Nothing usable was found: surface the original decode error.
      raise
  return parsed_json[key]


def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
  """Run a single selection round over `samples`.

  The samples are drawn onto `im` via `prompter.add_arrow_overlay_plt`, the
  annotated image is PNG-encoded, and the prompt built from `desc` and `top_n`
  is sent together with the encoded image to `vlm.query`.

  Args:
    prompter: Object providing `add_arrow_overlay_plt`.
    vlm: Object providing `query`, called with a [prompt, encoded image] list.
    im: Image to annotate.
    desc: Description of the target, forwarded to `make_prompt`.
    arm_coord: Passed to the overlay call as `arm_xy`.
    samples: Candidate samples to draw on the image.
    top_n: Number of candidates the prompt asks the model to pick.

  Returns:
    A pair `(ids, annotated_image)`: the "points" value parsed from the model's
    reply, and the annotated image returned by the overlay call.
  """
  annotated = prompter.add_arrow_overlay_plt(
      image=im, samples=samples, arm_xy=arm_coord, log_image=False
  )

  # imencode returns (success_flag, buffer); only the buffer is forwarded.
  png_buffer = cv2.imencode(".png", annotated)[1]

  reply = vlm.query([make_prompt(desc, top_n=top_n), png_buffer])
  return extract_json(reply, "points"), annotated


def _mark_selected(samples, arrow_ids):
  """Return the samples indexed by `arrow_ids`, recoloring each one red."""
  picked = []
  for selected_id in arrow_ids:
    sample = samples[selected_id]
    sample.coord.color = (255, 0, 0)
    picked.append(sample)
  return picked


def vip_runner(
    vlm,
    im,
    desc,
    style,
    action_spec,
    n_samples_init=25,
    n_samples_opt=10,
    n_iters=3,
    recursion_level=0,
):
  """Iteratively sample candidates, ask the model to select, and refit.

  At `recursion_level == 0` this runs `n_iters` rounds of: sample candidates
  around the current mean/std, ask the model for its top 3, mark them, and
  refit the distribution with `prompter.fit`. On the last round a further
  top-1 query is made over the three selected candidates.

  At `recursion_level > 0` the function calls itself three times one level
  down, pools the returned samples, relabels them 0..N-1, and asks the model
  for a single best candidate among them.

  Args:
    vlm: Model wrapper forwarded to `vip_perform_selection`.
    im: Image; `im.shape[0]` and `im.shape[1]` are read to place `arm_coord`
      at the image center.
    desc: Description of the target to locate.
    style: Dict passed to the prompter. NOTE: its "num_samples" entry is
      overwritten in place on every round (presumably read by
      `prompter.sample_actions` -- confirm against the `vip` module).
    action_spec: Mapping whose "loc" and "scale" entries seed the initial
      sampling mean and std.
    n_samples_init: Value written to `style["num_samples"]` for the first round.
    n_samples_opt: Value written to `style["num_samples"]` for later rounds.
    n_iters: Number of rounds at recursion level 0. Must be >= 1 when
      `recursion_level > 0`, because the nested result is unpacked as a
      3-tuple.
    recursion_level: Depth of nested runs; must be >= 0.

  Returns:
    `(output_ims, xy, selected_samples)` where `output_ims` is the list of
    annotated images, `xy` is the `.xy` of the fitted mean converted by
    `prompter.action_to_coord`, and `selected_samples` are the finally
    selected samples. If no round ran at recursion level 0, the 2-tuple
    `([], "Unable to understand query")` is returned instead.

  Raises:
    ValueError: If `recursion_level` is negative (previously this recursed
      without bound).
  """
  if recursion_level < 0:
    raise ValueError(
        f"recursion_level must be >= 0, got {recursion_level!r}"
    )

  prompter = vip.VisualIterativePrompter(
      style, action_spec, vip.SupportedEmbodiments.META_NAVIGATION
  )

  output_ims = []
  arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))

  if recursion_level == 0:
    center_mean = action_spec["loc"]
    center_std = action_spec["scale"]
    selected_samples = []
    for itr in trange(n_iters):
      # The first round uses the larger initial sample budget.
      style["num_samples"] = n_samples_init if itr == 0 else n_samples_opt
      samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
      arrow_ids, image_circles_np = vip_perform_selection(
          prompter, vlm, im, desc, arm_coord, samples, top_n=3
      )

      # Plot the selected circles in red on top of the annotated image.
      selected_samples = _mark_selected(samples, arrow_ids)
      output_ims.append(
          prompter.add_arrow_overlay_plt(
              image_circles_np, selected_samples, arm_coord
          )
      )

      # On the last round, pick one answer out of the selected ones. The
      # returned ids are still used to index the full `samples` list.
      if itr == n_iters - 1:
        arrow_ids, _ = vip_perform_selection(
            prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
        )
        selected_samples = _mark_selected(samples, arrow_ids)
        output_ims.append(
            prompter.add_arrow_overlay_plt(im, selected_samples, arm_coord)
        )
      center_mean, center_std = prompter.fit(arrow_ids, samples)

    if output_ims:
      return (
          output_ims,
          prompter.action_to_coord(center_mean, im, arm_coord).xy,
          selected_samples,
      )
    # No round ran (n_iters < 1): report failure in the legacy 2-tuple form.
    return [], "Unable to understand query"

  new_samples = []
  for _ in range(3):
    out_ims, _, cur_samples = vip_runner(
        vlm=vlm,
        im=im,
        desc=desc,
        style=style,
        action_spec=action_spec,
        n_samples_init=n_samples_init,
        n_samples_opt=n_samples_opt,
        n_iters=n_iters,
        recursion_level=recursion_level - 1,
    )
    output_ims += out_ims
    new_samples += cur_samples

  # Relabel so the pooled samples have unique labels matching their index.
  for sample_id, sample in enumerate(new_samples):
    sample.label = str(sample_id)
  arrow_ids, _ = vip_perform_selection(
      prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
  )

  selected_samples = _mark_selected(new_samples, arrow_ids)
  output_ims.append(
      prompter.add_arrow_overlay_plt(im, selected_samples, arm_coord)
  )
  center_mean, _ = prompter.fit(arrow_ids, new_samples)

  # output_ims is never empty here: an image was appended just above.
  return (
      output_ims,
      prompter.action_to_coord(center_mean, im, arm_coord).xy,
      selected_samples,
  )