File size: 4,672 Bytes
5c80958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# pylint: disable=line-too-long
"""Utils for visual iterative prompting.

A number of utility functions for VIP.
"""

import copy
import re

import numpy as np
import scipy.spatial.distance as distance
import matplotlib.pyplot as plt


def min_dist(coord, coords):
  """Return the smallest Euclidean distance from `coord` to any of `coords`.

  Args:
    coord: object exposing an `xy` attribute.
    coords: sequence of objects exposing an `xy` attribute.

  Returns:
    The minimum distance, or `np.inf` when `coords` is empty.
  """
  if coords:
    # Keep the singleton middle axis so broadcasting matches `coord.xy`.
    other_xys = np.asarray([[other.xy] for other in coords])
    offsets = other_xys - np.asarray(coord.xy)
    return np.linalg.norm(offsets, axis=-1).min()
  return np.inf


def coord_outside_image(coord, image, radius):
  """Check whether `coord` lies within `2 * radius` of the image border.

  Args:
    coord: object exposing an `xy` attribute holding (x, y).
    image: array-like with a `shape` of (height, width) or
      (height, width, channels).
    radius: marker radius; the excluded border margin is twice this value.

  Returns:
    True if the coordinate falls inside the border margin (or beyond the
    image), False otherwise.

  Raises:
    ValueError: if `image` is neither 2-D nor 3-D.
  """
  shape = image.shape
  if len(shape) not in (2, 3):
    raise ValueError(
        f'expected a 2-D or 3-D image, got shape {tuple(shape)}')
  # Only the two leading axes matter, so grayscale images are accepted too.
  height, image_width = shape[:2]
  x, y = coord.xy
  margin = 2 * radius
  x_outside = x > image_width - margin or x < margin
  y_outside = y > height - margin or y < margin
  return x_outside or y_outside


def is_invalid_coord(coord, coords, radius, image):
  """Report whether `coord` is unusable as a new marker position.

  A coordinate is invalid when it is closer than `1.5 * radius` to any of
  `coords`, or when `coord_outside_image` flags it for `image`.

  Args:
    coord: candidate object exposing an `xy` attribute.
    coords: already accepted coordinates to test against.
    radius: marker radius used for both the overlap and border checks.
    image: image array passed through to `coord_outside_image`.

  Returns:
    A truthy value if the coordinate is invalid, a falsy one otherwise.
  """
  too_close = min_dist(coord, coords) < 1.5 * radius
  if too_close:
    # Overlap alone is disqualifying; the border check is skipped.
    return too_close
  return coord_outside_image(coord, image, radius)


def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
  """Convert a polar offset from `arm_coord` into an (x, y) position.

  Args:
    angle: direction of the offset, in radians (fed to `np.cos`/`np.sin`).
    mag: signed length of the offset.
    arm_coord: (x, y) origin the offset is applied to.
    is_circle: when True, push the point a further `radius` along the same
      direction, on the side given by the sign of `mag`.
    radius: extra distance applied when `is_circle` is True.

  Returns:
    Tuple (x, y). Each offset component is truncated with `int()` before it
    is added.
  """
  x, y = arm_coord
  cos_a = np.cos(angle)
  sin_a = np.sin(angle)

  x += int(cos_a * mag)
  y += int(sin_a * mag)
  if not is_circle:
    return x, y

  x += int(cos_a * radius * np.sign(mag))
  y += int(sin_a * radius * np.sign(mag))
  return x, y


def coord_to_text_coord(coord, arm_coord, radius):
  """Compute a label position `radius` beyond `coord`, away from `arm_coord`.

  Args:
    coord: object exposing an `xy` attribute holding (x, y).
    arm_coord: (x, y) reference point the label is pushed away from.
    radius: distance to move past `coord` along the arm-to-coord direction.

  Returns:
    `arm_coord` unchanged when `coord` coincides with it, otherwise an
    (x, y) tuple of ints (truncated with `int()`).
  """
  offset = np.asarray(coord.xy) - arm_coord
  dist = np.linalg.norm(offset)
  if dist == 0:
    # No direction to push along.
    return arm_coord
  text_x = int(coord.xy[0] + radius * offset[0] / dist)
  text_y = int(coord.xy[1] + radius * offset[1] / dist)
  return (text_x, text_y)


def prep_aloha_frames(real_frame):
  """Prepare collage of ALOHA view frames.

  The input is treated as a 2x2 collage split at the integer midpoints of its
  first two axes. In the returned collage the top-right quadrant (the "left"
  view) and the bottom-left quadrant (the "side" view) are swapped; the other
  two quadrants are untouched. The input itself is not modified.

  Args:
    real_frame: array-like collage indexable as `frame[rows, cols]`. It is
      expected to have even height and width so that the swapped quadrants
      have matching shapes.

  Returns:
    Tuple `(markup_frame, right_frame, left_frame)`: the rearranged copy of
    the collage, a copy of the original bottom-right quadrant, and a copy of
    the original top-right quadrant.
  """
  markup_frame = copy.deepcopy(real_frame)
  mid_row = markup_frame.shape[0] // 2
  mid_col = markup_frame.shape[1] // 2

  # Copy the quadrants before writing, since both swaps read from the
  # pre-swap contents of `markup_frame`.
  side_frame = copy.deepcopy(markup_frame[mid_row:, :mid_col])
  right_frame = copy.deepcopy(markup_frame[mid_row:, mid_col:])
  left_frame = copy.deepcopy(markup_frame[:mid_row, mid_col:])

  markup_frame[mid_row:, :mid_col] = left_frame
  markup_frame[:mid_row, mid_col:] = side_frame
  return markup_frame, right_frame, left_frame


def parse_response(response, answer_key='Arrow: ['):
  """Extract the non-negative integers named in a text response.

  If `answer_key` occurs in `response`, only the text after its last split
  occurrence and before the next `]` is scanned; otherwise every run of
  digits in the whole response is used. A line is printed indicating which
  path was taken.

  Args:
    response: text to parse.
    answer_key: marker that introduces the bracketed answer list.

  Returns:
    List of ints in order of appearance (possibly empty).
  """
  if answer_key in response:
    print('parse_response from answer_key')
    searched_text = response.split(answer_key)[-1].split(']')[0]
  else:
    print('parse_response for all ints')
    searched_text = response
  return [int(digits) for digits in re.findall(r'\d+', searched_text)]


# TODO(ichter): normalize values by std
def compute_errors(action, true_action, verbose=False):
  """Compute error metrics between a predicted action and the true action.

  The last two entries of each action are compared as the "xy" part and the
  first entry as the "z" part.

  Args:
    action: predicted action array.
    true_action: reference action array, same layout as `action`.
    verbose: when True, print both actions and every metric.

  Returns:
    Dict with keys 'l2', 'cos_sim', 'l2_xy_error', 'cos_xy_sim' and
    'z_error'.
  """
  xy, true_xy = action[-2:], true_action[-2:]
  errors = {
      'l2': np.linalg.norm(action - true_action),
      # scipy returns a cosine *distance*; flip it into a similarity.
      'cos_sim': 1 - distance.cosine(action, true_action),
      'l2_xy_error': np.linalg.norm(xy - true_xy),
      'cos_xy_sim': 1 - distance.cosine(xy, true_xy),
      'z_error': np.abs(action[0] - true_action[0]),
  }

  if verbose:
    print('action: \t', [f'{a:.3f}' for a in action])
    print('true_action \t', [f'{a:.3f}' for a in true_action])
    # (label, key) pairs in the order the metrics are reported.
    report = (
        ('l2: \t\t', 'l2'),
        ('l2_xy_error: \t', 'l2_xy_error'),
        ('cos_sim: \t', 'cos_sim'),
        ('cos_xy_sim: \t', 'cos_xy_sim'),
        ('z_error: \t', 'z_error'),
    )
    for label, key in report:
      print(f'{label}{errors[key]:.3f}')

  return errors


def plot_errors(all_errors, error_types=None):
  """Plot errors across iterations.

  For each error type, the values recorded at the same iteration key are
  averaged over all calls and drawn as one line in its own subplot. The
  figure is then shown with `plt.show()`.

  Args:
    all_errors: collection with one entry per call; each entry maps an
      iteration key to a dict of error values keyed by error type. It is
      traversed once per error type, so it must be re-iterable.
    error_types: error names to plot. Defaults to 'l2', 'l2_xy_error',
      'z_error', 'cos_sim' and 'cos_xy_sim'.
  """
  if error_types is None:
    error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim']
  error_types = list(error_types)

  num_cols = 3
  # Keep the usual 2x3 layout, adding rows only when more than six error
  # types are requested (a fixed grid would otherwise be indexed out of range).
  num_rows = max(2, (len(error_types) + num_cols - 1) // num_cols)
  _, axs = plt.subplots(num_rows, num_cols, figsize=(15, 4 * num_rows))

  for i, error_type in enumerate(error_types):  # go through each error type
    all_iter_errors = {}
    for error_by_iter in all_errors:  # go through each call
      for itr in error_by_iter:  # go through each iteration
        # group the error under the iteration it happened at
        all_iter_errors.setdefault(itr, []).append(
            error_by_iter[itr][error_type])

    # Materialize the keys so plot() receives a plain sequence rather than a
    # dict view.
    iterations = list(all_iter_errors)
    mean_iter_errors = [np.mean(all_iter_errors[itr]) for itr in iterations]

    ax = axs[i // num_cols, i % num_cols]
    ax.plot(iterations, mean_iter_errors)
    ax.set_title(error_type)
  plt.show()