# pylint: disable=line-too-long
"""Visual Iterative Prompting functions.
Copied from experimental/users/ichter/vip/vip.py
Code to implement visual iterative prompting, an approach for querying VLMs.
See go/visual-iterative-prompting for more information.
These are used within Colabs such as:
* https://colab.corp.google.com/drive/1GnO-1urDCETWo3M3PpQKQ8TqT1Ql_jiS#scrollTo=5dUSoiz6Hplv
* https://colab.corp.google.com/drive/14AYsa4W68NnsaREFTUX7lTkSxpD5eHCO#scrollTo=qA2A_oTcGTzN
* https://colab.corp.google.com/drive/11H-WtHNYzBkr_lQpaa4ASeYy0HD29EXe#scrollTo=HapF0UIxdJM6
"""
import copy
import dataclasses
import enum
import io
from typing import Optional, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import vip_utils
@enum.unique
class SupportedEmbodiments(str, enum.Enum):
"""Embodiments supported by VIP."""
META_MANIPULATION = 'meta_manipulation'
ALOHA_MANIPULATION = 'aloha_manipulation'
META_NAVIGATION = 'meta_navigation'
@dataclasses.dataclass()
class Coordinate:
"""Coordinate with necessary information for visualizing annotation."""
# 2D image coordinates for the target annotation
xy: Tuple[int, int]
# Color and style of the coord.
color: Optional[float] = None
radius: Optional[int] = None
@dataclasses.dataclass()
class Sample:
"""Single Sample mapping actions to Coordinates."""
# 2D or 3D action
action: np.ndarray
# Coordinates for the main annotation
coord: Coordinate
# Coordinates for the text label
text_coord: Coordinate
# Label to display in the text label
label: str
class VisualIterativePrompter:
"""Visual Iterative Prompting class."""
def __init__(self, style, action_spec, embodiment):
self.embodiment = embodiment
self.style = style
self.action_spec = action_spec
self.fig_scale_size = None
# image preparer
# robot_to_image_canonical_coords
def action_to_coord(self, action, image, arm_xy, do_project=False):
"""Converts candidate action to image coordinate."""
if (self.embodiment == SupportedEmbodiments.META_MANIPULATION or
self.embodiment == SupportedEmbodiments.ALOHA_MANIPULATION):
return self.manipulation_action_to_coord(
action=action, image=image, arm_xy=arm_xy, do_project=do_project
)
elif self.embodiment == SupportedEmbodiments.META_NAVIGATION:
return self.navigation_action_to_coord(
action=action, image=image, center_xy=arm_xy, do_project=do_project
)
else:
raise NotImplementedError('Embodiment not supported.')
def manipulation_action_to_coord(
self, action, image, arm_xy, do_project=False
):
"""Converts a ZXY or XY action to an image coordinate.
Conversion is done based on style['focal_offset'] and action_spec['scale'].
Args:
action: z, y, x action in robot action space
image: image
arm_xy: x, y in image space
do_project: whether or not to project actions sampled outside the image to
the edge of the image
Returns:
      Coordinate with image x, y, arrow color, and circle radius.
"""
    # TODO(tedxiao): Refactor into common utility fns, add embodiment-specific
    # logic.
if self.action_spec['scale'][0] == 0: # no z dimension
norm_action = [(action[d] - self.action_spec['loc'][d]) /
(2 * self.action_spec['scale'][d]) for d in range(1, 3)]
norm_action_y, norm_action_x = norm_action
norm_action_z = 0
else:
norm_action = [(action[d] - self.action_spec['loc'][d]) /
(2 * self.action_spec['scale'][d]) for d in range(3)]
norm_action_z, norm_action_y, norm_action_x = norm_action
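    # Perspective scaling: arrows shrink as the action moves "into the page"
    # (larger norm_action_z); the focal length is floored to stay positive.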
focal_length = np.max(
[0.2, # positive focal lengths only
self.style['focal_offset'] / (self.style['focal_offset'] + norm_action_z)])
image_x = arm_xy[0] - (
self.action_spec['action_to_coord'] * norm_action_x * focal_length
)
image_y = arm_xy[1] - (
self.action_spec['action_to_coord'] * norm_action_y * focal_length
)
if vip_utils.coord_outside_image(
coord=Coordinate(xy=(int(image_x), int(image_y))),
image=image,
radius=self.style['radius']) and do_project:
# project the arrow to the edge of the image if too large
height, width, _ = image.shape
max_x = (
width - arm_xy[0] - 2 * self.style['radius']
if norm_action_x < 0
else arm_xy[0] - 2 * self.style['radius']
)
max_y = (
height - arm_xy[1] - 2 * self.style['radius']
if norm_action_y < 0
else arm_xy[1] - 2 * self.style['radius']
)
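      # Shrink the arrow by whichever axis overshoots the most so it fits
      # inside the image bounds.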
rescale_ratio = min(np.abs([
max_x / (self.action_spec['action_to_coord'] * norm_action_x),
max_y / (self.action_spec['action_to_coord'] * norm_action_y)]))
image_x = (
arm_xy[0]
- self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
)
image_y = (
arm_xy[1]
- self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
)
# blue is out of the page, red is into the page
    red_z = self.style['rgb_scale'] * ((norm_action_z + 1) / 2)
    blue_z = self.style['rgb_scale'] * (1 - (norm_action_z + 1) / 2)
color_z = np.clip(
(red_z, 0, blue_z),
0, self.style['rgb_scale'])
radius_z = int(np.clip((0.75 - norm_action_z / 4) * self.style['radius'],
0.5 * self.style['radius'], self.style['radius']))
return Coordinate(
xy=(int(image_x), int(image_y)),
color=color_z,
radius=radius_z,
)
def navigation_action_to_coord(
self, action, image, center_xy, do_project=False
):
"""Converts a ZXY or XY action to an image coordinate.
Conversion is done based on style['focal_offset'] and action_spec['scale'].
Args:
action: z, y, x action in robot action space
image: image
center_xy: x, y in image space
do_project: whether or not to project actions sampled outside the image to
the edge of the image
Returns:
      Coordinate with image x, y, arrow color, and circle radius.
"""
    # TODO(tedxiao): Refactor into common utility fns, add embodiment-specific
    # logic.
if self.action_spec['scale'][0] == 0: # no z dimension
norm_action = [(action[d] - self.action_spec['loc'][d]) /
(2 * self.action_spec['scale'][d]) for d in range(1, 3)]
norm_action_y, norm_action_x = norm_action
norm_action_z = 0
else:
norm_action = [(action[d] - self.action_spec['loc'][d]) /
(2 * self.action_spec['scale'][d]) for d in range(3)]
norm_action_z, norm_action_y, norm_action_x = norm_action
focal_length = np.max(
[0.2, # positive focal lengths only
self.style['focal_offset'] / (self.style['focal_offset'] + norm_action_z)])
image_x = center_xy[0] - (
self.action_spec['action_to_coord'] * norm_action_x * focal_length
)
image_y = center_xy[1] - (
self.action_spec['action_to_coord'] * norm_action_y * focal_length
)
    if (
        vip_utils.coord_outside_image(
            Coordinate(xy=(int(image_x), int(image_y))), image, self.style['radius']
        )
        and do_project
    ):
# project the arrow to the edge of the image if too large
height, width, _ = image.shape
max_x = (
width - center_xy[0] - 2 * self.style['radius']
if norm_action_x < 0
else center_xy[0] - 2 * self.style['radius']
)
max_y = (
height - center_xy[1] - 2 * self.style['radius']
if norm_action_y < 0
else center_xy[1] - 2 * self.style['radius']
)
rescale_ratio = min(np.abs([
max_x / (self.action_spec['action_to_coord'] * norm_action_x),
max_y / (self.action_spec['action_to_coord'] * norm_action_y)]))
image_x = (
center_xy[0]
- self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
)
image_y = (
center_xy[1]
- self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
)
return Coordinate(
xy=(int(image_x), int(image_y)),
color=0.1 * self.style['rgb_scale'],
radius=int(self.style['radius']),
)
def sample_actions(
self, image, arm_xy, loc, scale, true_action=None, max_itrs=1000
):
"""Sample actions from distribution.
Args:
image: image
arm_xy: x, y in image space of arm
loc: action distribution mean to sample from
      scale: action distribution standard deviation to sample from
true_action: action taken in demonstration if available
max_itrs: number of tries to get a valid sample
Returns:
samples: Samples with associated actions, coords, text_coords, labels.
"""
image = copy.deepcopy(image)
samples = []
actions = []
coords = []
text_coords = []
labels = []
# Keep track of oracle action if available.
true_label = None
if true_action is not None:
actions.append(true_action)
coord = self.action_to_coord(true_action, image, arm_xy)
coords.append(coord)
      text_coords.append(
          Coordinate(
              xy=vip_utils.coord_to_text_coord(coord, arm_xy, coord.radius)
          )
      )
      # Give the demonstration action a random label so it is not always 0.
      true_label = np.random.randint(self.style['num_samples'])
# labels.append(str(true_label) + '*')
labels.append(str(true_label))
# Generate all action samples.
for i in range(self.style['num_samples']):
if i == true_label:
continue
itrs = 0
# Generate action scaled appropriately.
action = np.clip(np.random.normal(loc, scale),
self.action_spec['min'], self.action_spec['max'])
# Convert sampled action to image coordinates.
coord = self.action_to_coord(action, image, arm_xy)
# Resample action if it results in invalid image annotation.
adjusted_scale = np.array(scale)
while ((vip_utils.is_invalid_coord(coord, coords, self.style['radius']*1.5, image)
or vip_utils.coord_outside_image(coord, image, self.style['radius']))
and itrs < max_itrs):
action = np.clip(np.random.normal(loc, adjusted_scale),
self.action_spec['min'], self.action_spec['max'])
coord = self.action_to_coord(action, image, arm_xy)
itrs += 1
# increase sampling range slightly if not finding a good sample
adjusted_scale *= 1.1
if itrs == max_itrs:
# If the final iteration results in invalid annotation, just clip
# to edge of image.
coord = self.action_to_coord(action, image, arm_xy, do_project=True)
# Compute image coordinates of text labels.
radius = coord.radius
text_coord = Coordinate(
xy=vip_utils.coord_to_text_coord(coord, arm_xy, radius)
)
actions.append(action)
coords.append(coord)
text_coords.append(text_coord)
labels.append(str(i))
for i in range(len(actions)):
sample = Sample(
action=actions[i],
coord=coords[i],
text_coord=text_coords[i],
          label=labels[i],
)
samples.append(sample)
return samples
def add_arrow_overlay_plt(self, image, samples, arm_xy, log_image=False):
"""Add arrows and circles to the image.
Args:
image: image
samples: Samples to visualize.
arm_xy: x, y image coordinates for EEF center.
log_image: Boolean for whether to save to CNS.
Returns:
image: image with visual prompts.
"""
# Add transparent arrows and circles
overlay = image.copy()
(original_image_height, original_image_width, _) = image.shape
white = (
self.style['rgb_scale'],
self.style['rgb_scale'],
self.style['rgb_scale'],
)
# Add arrows.
for sample in samples:
color = sample.coord.color
cv2.arrowedLine(
overlay, arm_xy, sample.coord.xy, color, self.style['thickness']
)
image = cv2.addWeighted(overlay, self.style['arrow_alpha'],
image, 1 - self.style['arrow_alpha'], 0)
overlay = image.copy()
# Add circles.
for sample in samples:
color = sample.coord.color
radius = sample.coord.radius
cv2.circle(
overlay,
sample.text_coord.xy,
radius,
color,
self.style['thickness'] + 1,
)
cv2.circle(overlay, sample.text_coord.xy, radius, white, -1)
image = cv2.addWeighted(overlay, self.style['circle_alpha'],
image, 1 - self.style['circle_alpha'], 0)
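    # Render the numeric labels with matplotlib, then convert the figure back
    # into an OpenCV image at the original resolution.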
dpi = plt.rcParams['figure.dpi']
if self.fig_scale_size is None:
# test saving a figure to decide size for text figure
fig_size = (original_image_width / dpi, original_image_height / dpi)
plt.subplots(1, figsize=fig_size)
plt.imshow(image, cmap='binary')
plt.axis('off')
fig = plt.gcf()
fig.tight_layout(pad=0)
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close()
buf.seek(0)
test_image = cv2.imdecode(
np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
self.fig_scale_size = original_image_width / test_image.shape[1]
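    # fig_scale_size compensates for matplotlib's figure padding so the
    # re-rendered image keeps the original resolution (and a consistent text size).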
# Add text to figure.
fig_size = (self.fig_scale_size * original_image_width / dpi,
self.fig_scale_size * original_image_height / dpi)
plt.subplots(1, figsize=fig_size)
plt.imshow(image, cmap='binary')
for sample in samples:
plt.text(
sample.text_coord.xy[0],
sample.text_coord.xy[1],
sample.label,
ha='center',
va='center',
color='k',
fontsize=self.style['fontsize'],
)
# Compile image.
plt.axis('off')
fig = plt.gcf()
fig.tight_layout(pad=0)
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close()
image = cv2.imdecode(np.frombuffer(buf.getvalue(), dtype=np.uint8),
cv2.IMREAD_COLOR)
image = cv2.resize(image, (original_image_width, original_image_height))
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# Optionally log images to CNS.
if log_image:
      raise NotImplementedError('TODO: log image to CNS')
return image
def fit(self, values, samples):
"""Fit a loc and scale to selected actions.
Args:
values: list of selected labels
samples: list of all Samples
Returns:
loc: mean of selected distribution
      scale: standard deviation of selected distribution
"""
actions = [sample.action for sample in samples]
labels = [sample.label for sample in samples]
if not values: # revert to initial distribution
print('GPT failed to return integer arrows')
loc = self.action_spec['loc']
scale = self.action_spec['scale']
elif len(values) == 1: # single response, add a distribution over it
index = np.where([label == str(values[-1]) for label in labels])[0][0]
action = actions[index]
print('action', action)
loc = action
scale = self.action_spec["min_scale"]
else: # fit distribution
selected_actions = []
for value in values:
idx = np.where([label == str(value) for label in labels])[0][0]
selected_actions.append(actions[idx])
print('selected_actions', selected_actions)
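      # Fit an independent 1D Gaussian to each of the 3 action dimensions,
      # flooring the fitted scale at min_scale so the distribution never collapses.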
loc_scale = [scipy.stats.norm.fit([action[d] for action in selected_actions]) for d in range(3)]
loc = [loc_scale[d][0] for d in range(3)]
scale = np.clip([loc_scale[d][1] for d in range(3)], self.action_spec['min_scale'], None)
print('loc', loc, '\nscale', scale)
return loc, scale
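

# Example usage (a minimal sketch). The `style` and `action_spec` keys mirror
# the ones referenced above, but the values are illustrative assumptions, and
# `image`, `arm_xy`, and the VLM query itself would come from the surrounding
# Colab:
#
#   style = {
#       'num_samples': 8, 'radius': 14, 'thickness': 2, 'fontsize': 10,
#       'rgb_scale': 255, 'focal_offset': 1.0,
#       'arrow_alpha': 0.6, 'circle_alpha': 0.9,
#   }
#   action_spec = {
#       'loc': np.zeros(3), 'scale': np.array([0.0, 0.1, 0.1]),
#       'min': -np.ones(3), 'max': np.ones(3),
#       'min_scale': np.array([0.01, 0.01, 0.01]),
#       'action_to_coord': 200,
#   }
#   prompter = VisualIterativePrompter(
#       style, action_spec, SupportedEmbodiments.META_MANIPULATION)
#   samples = prompter.sample_actions(
#       image, arm_xy, action_spec['loc'], action_spec['scale'])
#   annotated = prompter.add_arrow_overlay_plt(image, samples, arm_xy)
#   # ...query the VLM with `annotated`, parse the chosen labels into `values`...
#   loc, scale = prompter.fit(values, samples)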