pivot-iterative-visual-optimization commited on
Commit
f9a62da
·
verified ·
1 Parent(s): 53ef1bb

Upload 5 files

Browse files
Files changed (4) hide show
  1. app.py +1 -1
  2. vip.py +74 -139
  3. vip_runner.py +2 -2
  4. vip_utils.py +21 -29
app.py CHANGED
@@ -49,7 +49,7 @@ def run_vip(
49
  'min': [0, -300.0, -300],
50
  'max': [0, 300, 300],
51
  'action_to_coord': 250,
52
- 'robot': 'meta',
53
  }
54
 
55
  vlm = GPT4V(openai_api_key=openai_api_key)
 
49
  'min': [0, -300.0, -300],
50
  'max': [0, 300, 300],
51
  'action_to_coord': 250,
52
+ 'robot': None,
53
  }
54
 
55
  vlm = GPT4V(openai_api_key=openai_api_key)
vip.py CHANGED
@@ -1,18 +1,6 @@
1
- # pylint: disable=line-too-long
2
  """Visual Iterative Prompting functions.
3
 
4
- Copied from experimental/users/ichter/vip/vip.py
5
-
6
  Code to implement visual iterative prompting, an approach for querying VLMs.
7
- See go/visual-iterative-prompting for more information.
8
-
9
- These are used within Colabs such as:
10
- *
11
- https://colab.corp.google.com/drive/1GnO-1urDCETWo3M3PpQKQ8TqT1Ql_jiS#scrollTo=5dUSoiz6Hplv
12
- *
13
- https://colab.corp.google.com/drive/14AYsa4W68NnsaREFTUX7lTkSxpD5eHCO#scrollTo=qA2A_oTcGTzN
14
- *
15
- https://colab.corp.google.com/drive/11H-WtHNYzBkr_lQpaa4ASeYy0HD29EXe#scrollTo=HapF0UIxdJM6
16
  """
17
 
18
  import copy
@@ -31,9 +19,7 @@ import vip_utils
31
  class SupportedEmbodiments(str, enum.Enum):
32
  """Embodiments supported by VIP."""
33
 
34
- META_MANIPULATION = 'meta_manipulation'
35
- ALOHA_MANIPULATION = 'aloha_manipulation'
36
- META_NAVIGATION = 'meta_navigation'
37
 
38
 
39
  @dataclasses.dataclass()
@@ -74,95 +60,8 @@ class VisualIterativePrompter:
74
 
75
  def action_to_coord(self, action, image, arm_xy, do_project=False):
76
  """Converts candidate action to image coordinate."""
77
- if (self.embodiment == SupportedEmbodiments.META_MANIPULATION or
78
- self.embodiment == SupportedEmbodiments.ALOHA_MANIPULATION):
79
- return self.manipulation_action_to_coord(
80
- action=action, image=image, arm_xy=arm_xy, do_project=do_project
81
- )
82
- elif self.embodiment == SupportedEmbodiments.META_NAVIGATION:
83
- return self.navigation_action_to_coord(
84
- action=action, image=image, center_xy=arm_xy, do_project=do_project
85
- )
86
- else:
87
- raise NotImplementedError('Embodiment not supported.')
88
-
89
- def manipulation_action_to_coord(
90
- self, action, image, arm_xy, do_project=False
91
- ):
92
- """Converts a ZXY or XY action to an image coordinate.
93
-
94
- Conversion is done based on style['focal_offset'] and action_spec['scale'].
95
-
96
- Args:
97
- action: z, y, x action in robot action space
98
- image: image
99
- arm_xy: x, y in image space
100
- do_project: whether or not to project actions sampled outside the image to
101
- the edge of the image
102
-
103
- Returns:
104
- Dict coordinate with image x, y, arrow color, and circle radius.
105
- """
106
- # TODO(tedxiao): Refactor into common utiliy fns, add embodiment specific
107
- # logic.
108
- if self.action_spec['scale'][0] == 0: # no z dimension
109
- norm_action = [(action[d] - self.action_spec['loc'][d]) /
110
- (2 * self.action_spec['scale'][d]) for d in range(1, 3)]
111
- norm_action_y, norm_action_x = norm_action
112
- norm_action_z = 0
113
- else:
114
- norm_action = [(action[d] - self.action_spec['loc'][d]) /
115
- (2 * self.action_spec['scale'][d]) for d in range(3)]
116
- norm_action_z, norm_action_y, norm_action_x = norm_action
117
- focal_length = np.max(
118
- [0.2, # positive focal lengths only
119
- self.style['focal_offset'] / (self.style['focal_offset'] + norm_action_z)])
120
- image_x = arm_xy[0] - (
121
- self.action_spec['action_to_coord'] * norm_action_x * focal_length
122
- )
123
- image_y = arm_xy[1] - (
124
- self.action_spec['action_to_coord'] * norm_action_y * focal_length
125
- )
126
- if vip_utils.coord_outside_image(
127
- coord=Coordinate(xy=(int(image_x), int(image_y))),
128
- image=image,
129
- radius=self.style['radius']) and do_project:
130
- # project the arrow to the edge of the image if too large
131
- height, width, _ = image.shape
132
- max_x = (
133
- width - arm_xy[0] - 2 * self.style['radius']
134
- if norm_action_x < 0
135
- else arm_xy[0] - 2 * self.style['radius']
136
- )
137
- max_y = (
138
- height - arm_xy[1] - 2 * self.style['radius']
139
- if norm_action_y < 0
140
- else arm_xy[1] - 2 * self.style['radius']
141
- )
142
- rescale_ratio = min(np.abs([
143
- max_x / (self.action_spec['action_to_coord'] * norm_action_x),
144
- max_y / (self.action_spec['action_to_coord'] * norm_action_y)]))
145
- image_x = (
146
- arm_xy[0]
147
- - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
148
- )
149
- image_y = (
150
- arm_xy[1]
151
- - self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
152
- )
153
-
154
- # blue is out of the page, red is into the page
155
- red_z = self.style['rgb_scale'] * ((norm_action[0] + 1) / 2)
156
- blue_z = self.style['rgb_scale'] * (1 - (norm_action[0] + 1) / 2)
157
- color_z = np.clip(
158
- (red_z, 0, blue_z),
159
- 0, self.style['rgb_scale'])
160
- radius_z = int(np.clip((0.75 - norm_action_z / 4) * self.style['radius'],
161
- 0.5 * self.style['radius'], self.style['radius']))
162
- return Coordinate(
163
- xy=(int(image_x), int(image_y)),
164
- color=color_z,
165
- radius=radius_z,
166
  )
167
 
168
  def navigation_action_to_coord(
@@ -182,20 +81,26 @@ class VisualIterativePrompter:
182
  Returns:
183
  Dict coordinate with image x, y, arrow color, and circle radius.
184
  """
185
- # TODO(tedxiao): Refactor into common utiliy fns, add embodiment specific
186
- # logic.
187
  if self.action_spec['scale'][0] == 0: # no z dimension
188
- norm_action = [(action[d] - self.action_spec['loc'][d]) /
189
- (2 * self.action_spec['scale'][d]) for d in range(1, 3)]
 
 
 
190
  norm_action_y, norm_action_x = norm_action
191
  norm_action_z = 0
192
  else:
193
- norm_action = [(action[d] - self.action_spec['loc'][d]) /
194
- (2 * self.action_spec['scale'][d]) for d in range(3)]
 
 
 
195
  norm_action_z, norm_action_y, norm_action_x = norm_action
196
- focal_length = np.max(
197
- [0.2, # positive focal lengths only
198
- self.style['focal_offset'] / (self.style['focal_offset'] + norm_action_z)])
 
 
199
  image_x = center_xy[0] - (
200
  self.action_spec['action_to_coord'] * norm_action_x * focal_length
201
  )
@@ -220,9 +125,12 @@ class VisualIterativePrompter:
220
  if norm_action_y < 0
221
  else center_xy[1] - 2 * self.style['radius']
222
  )
223
- rescale_ratio = min(np.abs([
224
- max_x / (self.action_spec['action_to_coord'] * norm_action_x),
225
- max_y / (self.action_spec['action_to_coord'] * norm_action_y)]))
 
 
 
226
  image_x = (
227
  center_xy[0]
228
  - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
@@ -282,19 +190,28 @@ class VisualIterativePrompter:
282
  itrs = 0
283
 
284
  # Generate action scaled appropriately.
285
- action = np.clip(np.random.normal(loc, scale),
286
- self.action_spec['min'], self.action_spec['max'])
 
 
 
287
 
288
  # Convert sampled action to image coordinates.
289
  coord = self.action_to_coord(action, image, arm_xy)
290
 
291
  # Resample action if it results in invalid image annotation.
292
  adjusted_scale = np.array(scale)
293
- while ((vip_utils.is_invalid_coord(coord, coords, self.style['radius']*1.5, image)
294
- or vip_utils.coord_outside_image(coord, image, self.style['radius']))
295
- and itrs < max_itrs):
296
- action = np.clip(np.random.normal(loc, adjusted_scale),
297
- self.action_spec['min'], self.action_spec['max'])
 
 
 
 
 
 
298
  coord = self.action_to_coord(action, image, arm_xy)
299
  itrs += 1
300
  # increase sampling range slightly if not finding a good sample
@@ -325,7 +242,7 @@ class VisualIterativePrompter:
325
  samples.append(sample)
326
  return samples
327
 
328
- def add_arrow_overlay_plt(self, image, samples, arm_xy, log_image=False):
329
  """Add arrows and circles to the image.
330
 
331
  Args:
@@ -353,8 +270,13 @@ class VisualIterativePrompter:
353
  cv2.arrowedLine(
354
  overlay, arm_xy, sample.coord.xy, color, self.style['thickness']
355
  )
356
- image = cv2.addWeighted(overlay, self.style['arrow_alpha'],
357
- image, 1 - self.style['arrow_alpha'], 0)
 
 
 
 
 
358
 
359
  overlay = image.copy()
360
  # Add circles.
@@ -369,8 +291,13 @@ class VisualIterativePrompter:
369
  self.style['thickness'] + 1,
370
  )
371
  cv2.circle(overlay, sample.text_coord.xy, radius, white, -1)
372
- image = cv2.addWeighted(overlay, self.style['circle_alpha'],
373
- image, 1 - self.style['circle_alpha'], 0)
 
 
 
 
 
374
 
375
  dpi = plt.rcParams['figure.dpi']
376
  if self.fig_scale_size is None:
@@ -386,12 +313,15 @@ class VisualIterativePrompter:
386
  plt.close()
387
  buf.seek(0)
388
  test_image = cv2.imdecode(
389
- np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
 
390
  self.fig_scale_size = original_image_width / test_image.shape[1]
391
 
392
  # Add text to figure.
393
- fig_size = (self.fig_scale_size * original_image_width / dpi,
394
- self.fig_scale_size * original_image_height / dpi)
 
 
395
  plt.subplots(1, figsize=fig_size)
396
  plt.imshow(image, cmap='binary')
397
  for sample in samples:
@@ -412,15 +342,13 @@ class VisualIterativePrompter:
412
  buf = io.BytesIO()
413
  plt.savefig(buf, format='png')
414
  plt.close()
415
- image = cv2.imdecode(np.frombuffer(buf.getvalue(), dtype=np.uint8),
416
- cv2.IMREAD_COLOR)
 
417
 
418
  image = cv2.resize(image, (original_image_width, original_image_height))
419
  image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
420
 
421
- # Optionally log images to CNS.
422
- if log_image:
423
- raise NotImplementedError('TODO: log image too CNS')
424
  return image
425
 
426
  def fit(self, values, samples):
@@ -446,7 +374,7 @@ class VisualIterativePrompter:
446
  action = actions[index]
447
  print('action', action)
448
  loc = action
449
- scale = self.action_spec["min_scale"]
450
  else: # fit distribution
451
  selected_actions = []
452
  for value in values:
@@ -454,9 +382,16 @@ class VisualIterativePrompter:
454
  selected_actions.append(actions[idx])
455
  print('selected_actions', selected_actions)
456
 
457
- loc_scale = [scipy.stats.norm.fit([action[d] for action in selected_actions]) for d in range(3)]
 
 
 
458
  loc = [loc_scale[d][0] for d in range(3)]
459
- scale = np.clip([loc_scale[d][1] for d in range(3)], self.action_spec['min_scale'], None)
 
 
 
 
460
  print('loc', loc, '\nscale', scale)
461
 
462
  return loc, scale
 
 
1
  """Visual Iterative Prompting functions.
2
 
 
 
3
  Code to implement visual iterative prompting, an approach for querying VLMs.
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import copy
 
19
  class SupportedEmbodiments(str, enum.Enum):
20
  """Embodiments supported by VIP."""
21
 
22
+ HF_DEMO = 'hf_demo'
 
 
23
 
24
 
25
  @dataclasses.dataclass()
 
60
 
61
  def action_to_coord(self, action, image, arm_xy, do_project=False):
62
  """Converts candidate action to image coordinate."""
63
+ return self.navigation_action_to_coord(
64
+ action=action, image=image, center_xy=arm_xy, do_project=do_project
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
67
  def navigation_action_to_coord(
 
81
  Returns:
82
  Dict coordinate with image x, y, arrow color, and circle radius.
83
  """
 
 
84
  if self.action_spec['scale'][0] == 0: # no z dimension
85
+ norm_action = [
86
+ (action[d] - self.action_spec['loc'][d])
87
+ / (2 * self.action_spec['scale'][d])
88
+ for d in range(1, 3)
89
+ ]
90
  norm_action_y, norm_action_x = norm_action
91
  norm_action_z = 0
92
  else:
93
+ norm_action = [
94
+ (action[d] - self.action_spec['loc'][d])
95
+ / (2 * self.action_spec['scale'][d])
96
+ for d in range(3)
97
+ ]
98
  norm_action_z, norm_action_y, norm_action_x = norm_action
99
+ focal_length = np.max([
100
+ 0.2, # positive focal lengths only
101
+ self.style['focal_offset']
102
+ / (self.style['focal_offset'] + norm_action_z),
103
+ ])
104
  image_x = center_xy[0] - (
105
  self.action_spec['action_to_coord'] * norm_action_x * focal_length
106
  )
 
125
  if norm_action_y < 0
126
  else center_xy[1] - 2 * self.style['radius']
127
  )
128
+ rescale_ratio = min(
129
+ np.abs([
130
+ max_x / (self.action_spec['action_to_coord'] * norm_action_x),
131
+ max_y / (self.action_spec['action_to_coord'] * norm_action_y),
132
+ ])
133
+ )
134
  image_x = (
135
  center_xy[0]
136
  - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
 
190
  itrs = 0
191
 
192
  # Generate action scaled appropriately.
193
+ action = np.clip(
194
+ np.random.normal(loc, scale),
195
+ self.action_spec['min'],
196
+ self.action_spec['max'],
197
+ )
198
 
199
  # Convert sampled action to image coordinates.
200
  coord = self.action_to_coord(action, image, arm_xy)
201
 
202
  # Resample action if it results in invalid image annotation.
203
  adjusted_scale = np.array(scale)
204
+ while (
205
+ vip_utils.is_invalid_coord(
206
+ coord, coords, self.style['radius'] * 1.5, image
207
+ )
208
+ or vip_utils.coord_outside_image(coord, image, self.style['radius'])
209
+ ) and itrs < max_itrs:
210
+ action = np.clip(
211
+ np.random.normal(loc, adjusted_scale),
212
+ self.action_spec['min'],
213
+ self.action_spec['max'],
214
+ )
215
  coord = self.action_to_coord(action, image, arm_xy)
216
  itrs += 1
217
  # increase sampling range slightly if not finding a good sample
 
242
  samples.append(sample)
243
  return samples
244
 
245
+ def add_arrow_overlay_plt(self, image, samples, arm_xy):
246
  """Add arrows and circles to the image.
247
 
248
  Args:
 
270
  cv2.arrowedLine(
271
  overlay, arm_xy, sample.coord.xy, color, self.style['thickness']
272
  )
273
+ image = cv2.addWeighted(
274
+ overlay,
275
+ self.style['arrow_alpha'],
276
+ image,
277
+ 1 - self.style['arrow_alpha'],
278
+ 0,
279
+ )
280
 
281
  overlay = image.copy()
282
  # Add circles.
 
291
  self.style['thickness'] + 1,
292
  )
293
  cv2.circle(overlay, sample.text_coord.xy, radius, white, -1)
294
+ image = cv2.addWeighted(
295
+ overlay,
296
+ self.style['circle_alpha'],
297
+ image,
298
+ 1 - self.style['circle_alpha'],
299
+ 0,
300
+ )
301
 
302
  dpi = plt.rcParams['figure.dpi']
303
  if self.fig_scale_size is None:
 
313
  plt.close()
314
  buf.seek(0)
315
  test_image = cv2.imdecode(
316
+ np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
317
+ )
318
  self.fig_scale_size = original_image_width / test_image.shape[1]
319
 
320
  # Add text to figure.
321
+ fig_size = (
322
+ self.fig_scale_size * original_image_width / dpi,
323
+ self.fig_scale_size * original_image_height / dpi,
324
+ )
325
  plt.subplots(1, figsize=fig_size)
326
  plt.imshow(image, cmap='binary')
327
  for sample in samples:
 
342
  buf = io.BytesIO()
343
  plt.savefig(buf, format='png')
344
  plt.close()
345
+ image = cv2.imdecode(
346
+ np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
347
+ )
348
 
349
  image = cv2.resize(image, (original_image_width, original_image_height))
350
  image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
351
 
 
 
 
352
  return image
353
 
354
  def fit(self, values, samples):
 
374
  action = actions[index]
375
  print('action', action)
376
  loc = action
377
+ scale = self.action_spec['min_scale']
378
  else: # fit distribution
379
  selected_actions = []
380
  for value in values:
 
382
  selected_actions.append(actions[idx])
383
  print('selected_actions', selected_actions)
384
 
385
+ loc_scale = [
386
+ scipy.stats.norm.fit([action[d] for action in selected_actions])
387
+ for d in range(3)
388
+ ]
389
  loc = [loc_scale[d][0] for d in range(3)]
390
+ scale = np.clip(
391
+ [loc_scale[d][1] for d in range(3)],
392
+ self.action_spec['min_scale'],
393
+ None,
394
+ )
395
  print('loc', loc, '\nscale', scale)
396
 
397
  return loc, scale
vip_runner.py CHANGED
@@ -41,7 +41,7 @@ def extract_json(response, key):
41
  def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
42
  """Perform one selection pass given samples."""
43
  image_circles_np = prompter.add_arrow_overlay_plt(
44
- image=im, samples=samples, arm_xy=arm_coord, log_image=False
45
  )
46
 
47
  _, encoded_image_circles = cv2.imencode(".png", image_circles_np)
@@ -71,7 +71,7 @@ def vip_runner(
71
  """VIP."""
72
 
73
  prompter = vip.VisualIterativePrompter(
74
- style, action_spec, vip.SupportedEmbodiments.META_NAVIGATION
75
  )
76
 
77
  output_ims = []
 
41
  def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
42
  """Perform one selection pass given samples."""
43
  image_circles_np = prompter.add_arrow_overlay_plt(
44
+ image=im, samples=samples, arm_xy=arm_coord
45
  )
46
 
47
  _, encoded_image_circles = cv2.imencode(".png", image_circles_np)
 
71
  """VIP."""
72
 
73
  prompter = vip.VisualIterativePrompter(
74
+ style, action_spec, vip.SupportedEmbodiments.HF_DEMO
75
  )
76
 
77
  output_ims = []
vip_utils.py CHANGED
@@ -1,15 +1,13 @@
1
- # pylint: disable=line-too-long
2
  """Utils for visual iterative prompting.
3
 
4
  A number of utility functions for VIP.
5
  """
6
 
7
- import copy
8
  import re
9
 
 
10
  import numpy as np
11
  import scipy.spatial.distance as distance
12
- import matplotlib.pyplot as plt
13
 
14
 
15
  def min_dist(coord, coords):
@@ -49,23 +47,8 @@ def coord_to_text_coord(coord, arm_coord, radius):
49
  return arm_coord
50
  return (
51
  int(coord.xy[0] + radius * delta_coord[0] / np.linalg.norm(delta_coord)),
52
- int(coord.xy[1] + radius * delta_coord[1] / np.linalg.norm(delta_coord)))
53
-
54
-
55
- def prep_aloha_frames(real_frame):
56
- """Prepare collage of ALOHA view frames."""
57
- markup_frame = copy.deepcopy(real_frame)
58
- top_frame = copy.deepcopy(markup_frame[
59
- :int(markup_frame.shape[0] / 2), :int(markup_frame.shape[1] / 2)])
60
- side_frame = copy.deepcopy(markup_frame[
61
- int(markup_frame.shape[0] / 2):, :int(markup_frame.shape[1] / 2)])
62
- right_frame = copy.deepcopy(markup_frame[
63
- int(markup_frame.shape[0] / 2):, int(markup_frame.shape[1] / 2):])
64
- left_frame = copy.deepcopy(markup_frame[
65
- :int(markup_frame.shape[0] / 2), int(markup_frame.shape[1] / 2):])
66
- markup_frame[int(markup_frame.shape[0] / 2):, :int(markup_frame.shape[1] / 2)] = left_frame
67
- markup_frame[:int(markup_frame.shape[0] / 2), int(markup_frame.shape[1] / 2):] = side_frame
68
- return markup_frame, right_frame, left_frame
69
 
70
 
71
  def parse_response(response, answer_key='Arrow: ['):
@@ -82,7 +65,6 @@ def parse_response(response, answer_key='Arrow: ['):
82
  return values
83
 
84
 
85
- # TODO(ichter): normalize values by std
86
  def compute_errors(action, true_action, verbose=False):
87
  """Compute errors between a predicted action and true action."""
88
  l2_error = np.linalg.norm(action - true_action)
@@ -90,11 +72,13 @@ def compute_errors(action, true_action, verbose=False):
90
  l2_xy_error = np.linalg.norm(action[-2:] - true_action[-2:])
91
  cos_xy_sim = 1 - distance.cosine(action[-2:], true_action[-2:])
92
  z_error = np.abs(action[0] - true_action[0])
93
- errors = {'l2': l2_error,
94
- 'cos_sim': cos_sim,
95
- 'l2_xy_error': l2_xy_error,
96
- 'cos_xy_sim': cos_xy_sim,
97
- 'z_error': z_error}
 
 
98
 
99
  if verbose:
100
  print('action: \t', [f'{a:.3f}' for a in action])
@@ -111,19 +95,27 @@ def compute_errors(action, true_action, verbose=False):
111
  def plot_errors(all_errors, error_types=None):
112
  """Plot errors across iterations."""
113
  if error_types is None:
114
- error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim',]
 
 
 
 
 
 
115
 
116
  _, axs = plt.subplots(2, 3, figsize=(15, 8))
117
  for i, error_type in enumerate(error_types): # go through each error type
118
  all_iter_errors = {}
119
  for error_by_iter in all_errors: # go through each call
120
  for itr in error_by_iter: # go through each iteration
121
- if itr in all_iter_errors: # add error to the iteration it happened
122
  all_iter_errors[itr].append(error_by_iter[itr][error_type])
123
  else:
124
  all_iter_errors[itr] = [error_by_iter[itr][error_type]]
125
 
126
- mean_iter_errors = [np.mean(all_iter_errors[itr]) for itr in all_iter_errors]
 
 
127
 
128
  axs[i // 3, i % 3].plot(all_iter_errors.keys(), mean_iter_errors)
129
  axs[i // 3, i % 3].set_title(error_type)
 
 
1
  """Utils for visual iterative prompting.
2
 
3
  A number of utility functions for VIP.
4
  """
5
 
 
6
  import re
7
 
8
+ import matplotlib.pyplot as plt
9
  import numpy as np
10
  import scipy.spatial.distance as distance
 
11
 
12
 
13
  def min_dist(coord, coords):
 
47
  return arm_coord
48
  return (
49
  int(coord.xy[0] + radius * delta_coord[0] / np.linalg.norm(delta_coord)),
50
+ int(coord.xy[1] + radius * delta_coord[1] / np.linalg.norm(delta_coord)),
51
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  def parse_response(response, answer_key='Arrow: ['):
 
65
  return values
66
 
67
 
 
68
  def compute_errors(action, true_action, verbose=False):
69
  """Compute errors between a predicted action and true action."""
70
  l2_error = np.linalg.norm(action - true_action)
 
72
  l2_xy_error = np.linalg.norm(action[-2:] - true_action[-2:])
73
  cos_xy_sim = 1 - distance.cosine(action[-2:], true_action[-2:])
74
  z_error = np.abs(action[0] - true_action[0])
75
+ errors = {
76
+ 'l2': l2_error,
77
+ 'cos_sim': cos_sim,
78
+ 'l2_xy_error': l2_xy_error,
79
+ 'cos_xy_sim': cos_xy_sim,
80
+ 'z_error': z_error,
81
+ }
82
 
83
  if verbose:
84
  print('action: \t', [f'{a:.3f}' for a in action])
 
95
  def plot_errors(all_errors, error_types=None):
96
  """Plot errors across iterations."""
97
  if error_types is None:
98
+ error_types = [
99
+ 'l2',
100
+ 'l2_xy_error',
101
+ 'z_error',
102
+ 'cos_sim',
103
+ 'cos_xy_sim',
104
+ ]
105
 
106
  _, axs = plt.subplots(2, 3, figsize=(15, 8))
107
  for i, error_type in enumerate(error_types): # go through each error type
108
  all_iter_errors = {}
109
  for error_by_iter in all_errors: # go through each call
110
  for itr in error_by_iter: # go through each iteration
111
+ if itr in all_iter_errors: # add error to the iteration it happened
112
  all_iter_errors[itr].append(error_by_iter[itr][error_type])
113
  else:
114
  all_iter_errors[itr] = [error_by_iter[itr][error_type]]
115
 
116
+ mean_iter_errors = [
117
+ np.mean(all_iter_errors[itr]) for itr in all_iter_errors
118
+ ]
119
 
120
  axs[i // 3, i % 3].plot(all_iter_errors.keys(), mean_iter_errors)
121
  axs[i // 3, i % 3].set_title(error_type)