File size: 9,008 Bytes
18dd6ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Any, List, Tuple, Union
import torch
from torch.nn import functional as F
class Keypoints:
    """
    Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
    containing the x,y location and visibility flag of each keypoint. This tensor has shape
    (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.

    The visibility flag follows the COCO format and must be one of three integers:

    * v=0: not labeled (in which case x=y=0)
    * v=1: labeled but not visible
    * v=2: labeled and visible
    """

    def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
        """
        Arguments:
            keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
                The shape should be (N, K, 3) where N is the number of
                instances, and K is the number of keypoints per instance.
        """
        # Tensor inputs stay on their own device; anything else is materialized on the CPU.
        device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu")
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
        assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
        # The underlying (N, K, 3) float32 tensor of (x, y, visibility) triples.
        self.tensor = keypoints

    def __len__(self) -> int:
        """Return N, the number of instances."""
        return self.tensor.size(0)

    def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
        """Return a new object whose tensor is ``self.tensor.to(*args, **kwargs)``."""
        return type(self)(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        """The device that holds the underlying tensor."""
        return self.tensor.device

    def to_heatmap(
        self, boxes: torch.Tensor, heatmap_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert keypoint annotations to a heatmap of one-hot labels for training,
        as described in :paper:`Mask R-CNN`.

        Arguments:
            boxes: Nx4 tensor, the boxes to draw the keypoints to
            heatmap_size: integer side length of the square heatmap

        Returns:
            heatmaps:
                A tensor of shape (N, K), each element is integer spatial label
                in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
            valid:
                A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
        """
        return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
        """
        Create a new `Keypoints` by indexing on this `Keypoints`.

        The following usage are allowed:

        1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
        2. `new_kpts = kpts[2:10]`: return a slice of key points.
        3. `new_kpts = kpts[vector]`, where vector is a torch.BoolTensor
           with `length = len(kpts)`. Nonzero elements in the vector will be selected.

        Note that the returned Keypoints might share storage with this Keypoints,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            # Integer indexing drops the instance axis, so add it back to keep the
            # (1, K, 3) shape. Staying in tensor form (instead of wrapping the row in a
            # Python list) matters: `__init__` places non-tensor input on the CPU, which
            # would silently discard the device of this object and force a copy.
            return Keypoints(self.tensor[item].unsqueeze(0))
        return Keypoints(self.tensor[item])

    def __repr__(self) -> str:
        return "{}(num_instances={})".format(self.__class__.__name__, len(self.tensor))

    @staticmethod
    def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
        """
        Concatenates a list of Keypoints into a single Keypoints

        Arguments:
            keypoints_list (list[Keypoints])

        Returns:
            Keypoints: the concatenated Keypoints
        """
        assert isinstance(keypoints_list, (list, tuple))
        assert len(keypoints_list) > 0
        assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)

        # Build the result with the class of the first element so subclasses are preserved.
        cat_kpts = type(keypoints_list[0])(
            torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
        )
        return cat_kpts
# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop)
def _keypoints_to_heatmap(
    keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Turn keypoint locations into per-keypoint spatial class labels over a square heatmap,
    suitable as targets for a softmax taken across space.

    Each coordinate is mapped from the half-open continuous interval [x1, x2) of its roi to a
    discrete bin in the closed interval [0, heatmap_size - 1]. The continuous/discrete
    convention is the one from Heckbert 1990 ("What is the coordinate of a pixel?"):
    d = floor(c) and c = d + 0.5, with d discrete and c continuous.

    Arguments:
        keypoints: tensor of keypoint locations of shape (N, K, 3).
        rois: Nx4 tensor of rois in xyxy format
        heatmap_size: integer side length of square heatmap.

    Returns:
        heatmaps: A tensor of shape (N, K) containing an integer spatial label
            in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
        valid: A tensor of shape (N, K) containing whether each keypoint is in
            the roi or not.
    """
    if rois.numel() == 0:
        # No boxes at all: return two empty integer tensors.
        return rois.new().long(), rois.new().long()

    # Box edges as (N, 1) columns so they broadcast against the (N, K) coordinates.
    left = rois[:, 0][:, None]
    top = rois[:, 1][:, None]
    right = rois[:, 2][:, None]
    bottom = rois[:, 3][:, None]

    # Number of heatmap bins per unit of image length along each axis.
    bins_per_unit_x = heatmap_size / (right - left)
    bins_per_unit_y = heatmap_size / (bottom - top)

    kp_x = keypoints[..., 0]
    kp_y = keypoints[..., 1]

    col = ((kp_x - left) * bins_per_unit_x).floor().long()
    row = ((kp_y - top) * bins_per_unit_y).floor().long()

    # A keypoint lying exactly on the far edge of its box would land one bin past the end;
    # assign it to the last bin instead.
    last_bin = heatmap_size - 1
    col.masked_fill_(kp_x == right, last_bin)
    row.masked_fill_(kp_y == bottom, last_bin)

    inside = (col >= 0) & (col < heatmap_size) & (row >= 0) & (row < heatmap_size)
    labeled = keypoints[..., 2] > 0
    valid = (inside & labeled).long()

    # Row-major linear index into the heatmap, zeroed wherever the keypoint is not valid.
    heatmaps = (row * heatmap_size + col) * valid
    return heatmaps, valid
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Decode predicted keypoint positions from per-ROI heatmaps.

    Each ROI's heatmap is resized (bicubic) to the ROI's own size rounded up to whole
    pixels, and the location of the strongest response per keypoint is reported.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """
    num_rois, num_keypoints = maps.shape[:2]

    # Box extents, never smaller than one pixel, and the same extents rounded up to integers
    # (the resolution each heatmap is resized to).
    box_w = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    box_h = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    grid_w = box_w.ceil()
    grid_h = box_h.ceil()

    # Factors that map a position on the rounded-up grid back onto the real box extent.
    x_shrink = box_w / grid_w
    y_shrink = box_h / grid_h

    x0 = rois[:, 0]
    y0 = rois[:, 1]

    kp_index = torch.arange(num_keypoints, device=maps.device)
    result = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    for i in range(num_rois):
        target_hw = (int(grid_h[i]), int(grid_w[i]))
        upsampled = F.interpolate(maps[[i]], size=target_hw, mode="bicubic", align_corners=False)

        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        upsampled = upsampled.reshape(upsampled.shape[1:])  # keypoints x H x W

        # Softmax over space, shifted by the per-keypoint maximum of the resized map.
        peak_logit, _ = upsampled.view(num_keypoints, -1).max(1)
        peak_logit = peak_logit.view(num_keypoints, 1, 1)
        exp_full = (upsampled - peak_logit).exp_()
        exp_pool = (maps[i] - peak_logit).exp_()

        # Scores are produced over the resized H x W region but normalized with the
        # POOL_H x POOL_W sum, so that objects of different absolute sizes get more
        # comparable scores.
        scores = exp_full / exp_pool.sum((1, 2), keepdim=True)

        # Row/column of the strongest response of every keypoint on the resized grid.
        grid_cols = upsampled.shape[2]
        flat_peak = upsampled.view(num_keypoints, -1).argmax(1)
        col = flat_peak % grid_cols
        row = (flat_peak - col) // grid_cols

        peak_scores = scores[kp_index, row, col]
        assert (peak_scores == scores.view(num_keypoints, -1).max(1)[0]).all()

        # Pixel centre (d + 0.5), scaled back to the real box and shifted into image space.
        result[i, :, 0] = (col.float() + 0.5) * x_shrink[i] + x0[i]
        result[i, :, 1] = (row.float() + 0.5) * y_shrink[i] + y0[i]
        result[i, :, 2] = upsampled[kp_index, row, col]
        result[i, :, 3] = peak_scores

    return result
|