File size: 4,672 Bytes
5c80958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# pylint: disable=line-too-long
"""Utils for visual iterative prompting.
A number of utility functions for VIP.
"""
import copy
import re
import numpy as np
import scipy.spatial.distance as distance
import matplotlib.pyplot as plt
def min_dist(coord, coords):
    """Return the smallest Euclidean distance from `coord` to any of `coords`.

    Args:
        coord: object exposing an `xy` attribute (the query position).
        coords: sequence of objects exposing `xy`; its emptiness is tested
            with `not coords`.

    Returns:
        `np.inf` when `coords` is empty, otherwise the minimum distance.
    """
    if not coords:
        return np.inf
    target = np.asarray(coord.xy)
    # Each position is wrapped in its own row so the stacked array keeps a
    # singleton middle axis; the norm is taken over the last axis.
    stacked = np.asarray([[other.xy] for other in coords])
    distances = np.linalg.norm(stacked - target, axis=-1)
    return distances.min()
def coord_outside_image(coord, image, radius):
    """Tell whether `coord` lies within two radii of the image border.

    Args:
        coord: object exposing an `xy` attribute, unpacked as `(x, y)`.
        image: array whose `shape` unpacks to exactly three values
            `(height, width, _)`.
        radius: marker radius; the excluded border band is `2 * radius` wide.

    Returns:
        Truthy when x or y falls inside the border band (or beyond the image).
    """
    rows, cols, _ = image.shape
    margin = 2 * radius
    px, py = coord.xy
    # Horizontal check first; fall through to the vertical one otherwise.
    if px > cols - margin or px < margin:
        return True
    return py > rows - margin or py < margin
def is_invalid_coord(coord, coords, radius, image):
    """Reject a candidate that crowds existing coords or hugs the image edge.

    Args:
        coord: candidate object exposing `xy`.
        coords: already accepted objects exposing `xy`.
        radius: marker radius; candidates closer than `1.5 * radius` to an
            existing coord count as overlapping.
        image: image array forwarded to `coord_outside_image`.

    Returns:
        Truthy when the candidate overlaps another coord or is outside the
        usable image area.
    """
    min_separation = 1.5 * radius
    crowded = min_dist(coord, coords) < min_separation
    if crowded:
        return crowded
    return coord_outside_image(coord, image, radius)
def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
    """Convert a polar offset from `arm_coord` into an `(x, y)` position.

    Args:
        angle: direction in radians (fed to `np.cos` / `np.sin`).
        mag: offset length along that direction; may be negative.
        arm_coord: pair `(x, y)` used as the origin of the offset.
        is_circle: when true, push the point a further `radius` along the
            same direction, on the side given by the sign of `mag`.
        radius: extra offset applied when `is_circle` is set.

    Returns:
        Tuple `(x, y)`; every added offset is truncated with `int`.
    """
    origin_x, origin_y = arm_coord
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    out_x = origin_x + int(cos_a * mag)
    out_y = origin_y + int(sin_a * mag)
    if not is_circle:
        return out_x, out_y
    # The radius term is truncated separately from the magnitude term.
    direction = np.sign(mag)
    out_x += int(cos_a * radius * direction)
    out_y += int(sin_a * radius * direction)
    return out_x, out_y
def coord_to_text_coord(coord, arm_coord, radius):
    """Compute a label position `radius` beyond `coord`, away from `arm_coord`.

    Args:
        coord: object exposing `xy`.
        arm_coord: reference position subtracted from `coord.xy`.
        radius: distance to move past `coord` along the arm-to-coord direction.

    Returns:
        `arm_coord` itself when `coord` coincides with it; otherwise a tuple
        of two ints.
    """
    offset = np.asarray(coord.xy) - arm_coord
    length = np.linalg.norm(offset)
    if length == 0:
        # No direction to push along; hand back the reference unchanged.
        return arm_coord
    text_x = coord.xy[0] + radius * offset[0] / length
    text_y = coord.xy[1] + radius * offset[1] / length
    return int(text_x), int(text_y)
def prep_aloha_frames(real_frame):
    """Prepare collage of ALOHA view frames.

    The input is treated as a 2x2 grid of quadrants split at the integer
    midpoints of its first two axes. In the returned collage the top-right
    and bottom-left quadrants are exchanged; the other two stay in place.
    The input frame itself is not modified.

    Args:
        real_frame: array indexable as `[rows, cols]` with a `shape`.
            NOTE(review): odd heights/widths give quadrants of different
            shapes, so the swap assignments will generally fail; even
            dimensions appear to be assumed.

    Returns:
        Tuple `(markup_frame, right_frame, left_frame)` where `right_frame`
        is a copy of the original bottom-right quadrant and `left_frame` is a
        copy of the original top-right quadrant.
    """
    markup_frame = copy.deepcopy(real_frame)
    mid_row = markup_frame.shape[0] // 2
    mid_col = markup_frame.shape[1] // 2

    # Copies are required: the source regions are overwritten below and the
    # returned quadrants must not alias the collage.
    side_frame = copy.deepcopy(markup_frame[mid_row:, :mid_col])
    right_frame = copy.deepcopy(markup_frame[mid_row:, mid_col:])
    left_frame = copy.deepcopy(markup_frame[:mid_row, mid_col:])

    markup_frame[mid_row:, :mid_col] = left_frame
    markup_frame[:mid_row, mid_col:] = side_frame
    return markup_frame, right_frame, left_frame
def parse_response(response, answer_key='Arrow: ['):
    """Extract the non-negative integers from a text response.

    When `answer_key` occurs in `response`, only the text between its final
    occurrence and the next `]` is scanned; otherwise the entire response is
    scanned. Digit runs are matched with `\\d+`, so signs are ignored.

    Args:
        response: text to parse.
        answer_key: marker that introduces the bracketed answer.

    Returns:
        List of ints in order of appearance (possibly empty).
    """
    if answer_key in response:
        print('parse_response from answer_key')
        after_key = response.split(answer_key)[-1]
        searched = after_key.partition(']')[0]
    else:
        print('parse_response for all ints')
        searched = response
    return [int(digits) for digits in re.findall(r'\d+', searched)]
# TODO(ichter): normalize values by std
def compute_errors(action, true_action, verbose=False):
    """Compute errors between a predicted action and true action.

    Args:
        action: predicted action array (supports `-`, slicing and indexing).
        true_action: reference action array of matching shape.
        verbose: when true, print both actions and every metric.

    Returns:
        Dict with keys 'l2', 'cos_sim', 'l2_xy_error', 'cos_xy_sim' and
        'z_error'. The '*_xy_*' entries use the last two elements of each
        action; 'z_error' is the absolute difference of the first elements.
    """
    pred_tail = action[-2:]
    true_tail = true_action[-2:]
    errors = {
        'l2': np.linalg.norm(action - true_action),
        'cos_sim': 1 - distance.cosine(action, true_action),
        'l2_xy_error': np.linalg.norm(pred_tail - true_tail),
        'cos_xy_sim': 1 - distance.cosine(pred_tail, true_tail),
        'z_error': np.abs(action[0] - true_action[0]),
    }
    if not verbose:
        return errors

    print('action: \t', [f'{a:.3f}' for a in action])
    print('true_action \t', [f'{a:.3f}' for a in true_action])
    # (label, key) pairs in the order the metrics are reported.
    report = (
        ('l2: \t\t', 'l2'),
        ('l2_xy_error: \t', 'l2_xy_error'),
        ('cos_sim: \t', 'cos_sim'),
        ('cos_xy_sim: \t', 'cos_xy_sim'),
        ('z_error: \t', 'z_error'),
    )
    for label, key in report:
        print(f'{label}{errors[key]:.3f}')
    return errors
def plot_errors(all_errors, error_types=None):
    """Plot errors across iterations.

    For every error type, the values recorded at the same iteration key are
    averaged over all calls and drawn on that error type's own subplot.

    Args:
        all_errors: iterable of per-call records; each record is iterated to
            obtain iteration keys and indexed as `record[itr][error_type]`.
            It is traversed once per error type, so it must be re-iterable.
        error_types: names of the errors to plot. Defaults to 'l2',
            'l2_xy_error', 'z_error', 'cos_sim' and 'cos_xy_sim'.

    The figure has three columns and at least two rows; extra rows are added
    when more than six error types are requested.
    """
    if error_types is None:
        error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim']
    error_types = list(error_types)

    n_cols = 3
    # Ceiling division; never fewer than the original two rows.
    n_rows = max(2, -(-len(error_types) // n_cols))
    # squeeze=False guarantees a 2-D axes array for `axs[row, col]` indexing.
    _, axs = plt.subplots(
        n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)

    for idx, error_type in enumerate(error_types):
        # Group the values of this error type by iteration, keeping the
        # order in which iterations are first seen.
        errors_by_iter = {}
        for error_by_iter in all_errors:
            for itr in error_by_iter:
                errors_by_iter.setdefault(itr, []).append(
                    error_by_iter[itr][error_type])

        iterations = list(errors_by_iter)
        mean_errors = [np.mean(errors_by_iter[itr]) for itr in iterations]

        ax = axs[idx // n_cols, idx % n_cols]
        ax.plot(iterations, mean_errors)
        ax.set_title(error_type)
    plt.show()
|