pivot-demo / vip_utils.py
pivot-iterative-visual-optimization's picture
Upload 6 files
5c80958 verified
raw
history blame
4.67 kB
# pylint: disable=line-too-long
"""Utils for visual iterative prompting.
A number of utility functions for VIP.
"""
import copy
import re
import numpy as np
import scipy.spatial.distance as distance
import matplotlib.pyplot as plt
def min_dist(coord, coords):
    """Return the smallest Euclidean distance from `coord` to any of `coords`.

    Args:
      coord: object exposing its position through an `xy` attribute.
      coords: sequence of objects that each expose an `xy` attribute.

    Returns:
      The minimum distance, or `np.inf` when `coords` is empty.
    """
    if not coords:
        # Nothing to compare against, so nothing can be "too close".
        return np.inf
    origin = np.asarray(coord.xy)
    # Each position is wrapped in its own list so the norm below is taken
    # per point along the last axis.
    others = np.asarray([[other.xy] for other in coords])
    distances = np.linalg.norm(others - origin, axis=-1)
    return distances.min()
def coord_outside_image(coord, image, radius):
    """Check whether `coord` falls within `2 * radius` of the image border.

    Args:
      coord: object whose `xy` attribute unpacks to `(x, y)`.
      image: array whose first two `shape` entries are (height, width). Any
        further dimensions (e.g. a channel axis) are ignored, so both 2-D
        grayscale and 3-D color arrays are accepted.
      radius: marker radius; the excluded margin is twice this value.

    Returns:
      A truthy value when `x` or `y` is closer than `2 * radius` to, or
      beyond, any edge of the image; a falsy value otherwise.
    """
    # Only the leading two dims matter; slicing avoids a ValueError on
    # images that have no channel axis.
    height, image_width = image.shape[:2]
    x, y = coord.xy
    margin = 2 * radius
    x_outside = x > image_width - margin or x < margin
    y_outside = y > height - margin or y < margin
    return x_outside or y_outside
def is_invalid_coord(coord, coords, radius, image):
    """Tell whether `coord` is unusable as a new marker position.

    A coordinate is invalid when it sits closer than `1.5 * radius` to any
    entry of `coords`, or when `coord_outside_image` reports it as outside
    the usable image area.

    Args:
      coord: candidate object exposing an `xy` attribute.
      coords: already accepted objects exposing `xy`.
      radius: marker radius used for both checks.
      image: image array passed through to `coord_outside_image`.

    Returns:
      A truthy value when the coordinate must be rejected.
    """
    too_close = min_dist(coord, coords) < 1.5 * radius
    if too_close:
        # Overlap already disqualifies the point; skip the border check.
        return too_close
    return coord_outside_image(coord, image, radius)
def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
    """Convert a polar offset from `arm_coord` into an `(x, y)` position.

    Args:
      angle: direction of the offset, in radians (fed to `np.cos`/`np.sin`).
      mag: signed length of the offset.
      arm_coord: pair `(x, y)` the offset starts from.
      is_circle: when true, push the point a further `radius` along the same
        direction, following the sign of `mag`.
      radius: extra distance applied when `is_circle` is set.

    Returns:
      Tuple `(x, y)`; every offset term is truncated to an int before it is
      added.
    """
    x, y = arm_coord
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    x += int(cos_a * mag)
    y += int(sin_a * mag)
    if is_circle:
        # Truncated separately from the main offset, so the two terms are
        # not merged into a single int() call.
        direction = np.sign(mag)
        x += int(cos_a * radius * direction)
        y += int(sin_a * radius * direction)
    return x, y
def coord_to_text_coord(coord, arm_coord, radius):
    """Place a text anchor `radius` beyond `coord`, heading away from the arm.

    Args:
      coord: object exposing an `xy` attribute.
      arm_coord: reference position the direction is measured from.
      radius: distance to move past `coord` along the arm-to-coord direction.

    Returns:
      `arm_coord` itself when `coord` coincides with it (no direction is
      defined); otherwise a tuple of two ints.
    """
    delta_coord = np.asarray(coord.xy) - arm_coord
    dist = np.linalg.norm(delta_coord)
    if dist == 0:
        return arm_coord
    text_x = coord.xy[0] + radius * delta_coord[0] / dist
    text_y = coord.xy[1] + radius * delta_coord[1] / dist
    return int(text_x), int(text_y)
def prep_aloha_frames(real_frame):
    """Prepare collage of ALOHA view frames.

    The frame is cut into four quadrants at `int(shape[0] / 2)` and
    `int(shape[1] / 2)`. A copy of the frame is returned with its top-right
    and bottom-left quadrants swapped, together with copies of two of the
    original quadrants. `real_frame` itself is not modified.

    Args:
      real_frame: array-like with at least two dimensions that supports
        slicing and slice assignment. Assumes even height and width so that
        the swapped quadrants have matching shapes -- TODO confirm with
        callers.

    Returns:
      Tuple `(markup_frame, right_frame, left_frame)` where `markup_frame`
      is the rearranged copy, `right_frame` is the original bottom-right
      quadrant and `left_frame` is the original top-right quadrant.
    """
    markup_frame = copy.deepcopy(real_frame)
    mid_row = int(markup_frame.shape[0] / 2)
    mid_col = int(markup_frame.shape[1] / 2)

    # Snapshot the quadrants before writing: the slices are views into
    # `markup_frame`, so the swap below would otherwise clobber its sources.
    # The top-left quadrant is neither moved nor returned, so it needs no copy.
    side_frame = copy.deepcopy(markup_frame[mid_row:, :mid_col])
    right_frame = copy.deepcopy(markup_frame[mid_row:, mid_col:])
    left_frame = copy.deepcopy(markup_frame[:mid_row, mid_col:])

    markup_frame[mid_row:, :mid_col] = left_frame
    markup_frame[:mid_row, mid_col:] = side_frame
    return markup_frame, right_frame, left_frame
def parse_response(response, answer_key='Arrow: ['):
    """Extract the non-negative integers named in a text response.

    When `answer_key` occurs in `response`, only the text between its last
    occurrence and the next `]` is scanned. Otherwise every run of digits in
    the whole response is used. A line describing which path was taken is
    printed.

    Args:
      response: text to parse.
      answer_key: marker that introduces the bracketed answer list.

    Returns:
      List of ints in order of appearance; empty when no digits are found.
    """
    if answer_key in response:
        print('parse_response from answer_key')
        text = response.split(answer_key)[-1].partition(']')[0]
    else:
        print('parse_response for all ints')
        text = response
    return [int(token) for token in re.findall(r'\d+', text)]
# TODO(ichter): normalize values by std
def compute_errors(action, true_action, verbose=False):
    """Compute errors between a predicted action and true action.

    Args:
      action: predicted action as a 1-D array-like (ndarray, list or tuple).
      true_action: reference action with the same length as `action`.
      verbose: when true, print both actions and every metric.

    Returns:
      Dict with keys 'l2', 'cos_sim', 'l2_xy_error', 'cos_xy_sim' and
      'z_error'. The '*xy*' metrics are computed on the last two elements
      and 'z_error' on element 0 of each action.
    """
    # Coerce so that plain sequences support the element-wise arithmetic
    # below; ndarrays pass through unchanged.
    action = np.asarray(action)
    true_action = np.asarray(true_action)

    action_xy = action[-2:]
    true_xy = true_action[-2:]

    l2_error = np.linalg.norm(action - true_action)
    # scipy's cosine() is a distance (1 - similarity), hence the flip.
    cos_sim = 1 - distance.cosine(action, true_action)
    l2_xy_error = np.linalg.norm(action_xy - true_xy)
    cos_xy_sim = 1 - distance.cosine(action_xy, true_xy)
    z_error = np.abs(action[0] - true_action[0])
    errors = {
        'l2': l2_error,
        'cos_sim': cos_sim,
        'l2_xy_error': l2_xy_error,
        'cos_xy_sim': cos_xy_sim,
        'z_error': z_error,
    }

    if verbose:
        print('action: \t', [f'{a:.3f}' for a in action])
        print('true_action \t', [f'{a:.3f}' for a in true_action])
        print(f'l2: \t\t{l2_error:.3f}')
        print(f'l2_xy_error: \t{l2_xy_error:.3f}')
        print(f'cos_sim: \t{cos_sim:.3f}')
        print(f'cos_xy_sim: \t{cos_xy_sim:.3f}')
        print(f'z_error: \t{z_error:.3f}')
    return errors
def plot_errors(all_errors, error_types=None):
    """Plot errors across iterations.

    For each error type, the values recorded at the same iteration key are
    averaged over all calls and drawn as one line in its own subplot. The
    figure is then shown with `plt.show()`.

    Args:
      all_errors: re-iterable collection with one mapping per call. Each
        mapping goes from an iteration key to a dict of error values keyed
        by error type (e.g. the dicts returned by `compute_errors`).
      error_types: iterable of error-type keys to plot. Defaults to 'l2',
        'l2_xy_error', 'z_error', 'cos_sim' and 'cos_xy_sim'.
    """
    if error_types is None:
        error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim']
    error_types = list(error_types)

    # Keep the 2x3 layout for up to six types and add rows of three beyond
    # that, instead of indexing past a fixed-size grid.
    n_cols = 3
    n_rows = max(2, -(-len(error_types) // n_cols))
    _, axs = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))

    for i, error_type in enumerate(error_types):  # go through each error type
        all_iter_errors = {}
        for error_by_iter in all_errors:  # go through each call
            for itr, errors in error_by_iter.items():  # go through each iteration
                # add error to the iteration it happened
                all_iter_errors.setdefault(itr, []).append(errors[error_type])

        # X values follow the order in which iteration keys were first seen.
        iterations = list(all_iter_errors)
        mean_iter_errors = [np.mean(all_iter_errors[itr]) for itr in iterations]

        ax = axs[i // n_cols, i % n_cols]
        ax.plot(iterations, mean_iter_errors)
        ax.set_title(error_type)
    plt.show()