# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import torch
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
from .base_bbox_coder import BaseBBoxCoder


@TASK_UTILS.register_module()
class TBLRBBoxCoder(BaseBBoxCoder):
    """TBLR BBox coder.

    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
    this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
    right) and decode it back to the original.

    Args:
        normalizer (list | float): Normalization factor to be
          divided with when coding the coordinates. If it is a list, it should
          have length of 4 indicating normalization factor in tblr dims.
          Otherwise it is a unified float factor for all dims. Default: 4.0
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Defaults to True.
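
    Example (a minimal round-trip sketch; the delta values follow from the
    default ``normalize_by_wh=True`` behavior of :func:`bboxes2tblr`):
        >>> coder = TBLRBBoxCoder(normalizer=4.0)
        >>> priors = torch.Tensor([[0., 0., 10., 10.]])  # center (5, 5)
        >>> gts = torch.Tensor([[1., 1., 9., 9.]])
        >>> deltas = coder.encode(priors, gts)  # each side: 4 / 10 / 4.0
        >>> coder.decode(priors, deltas)
        tensor([[1., 1., 9., 9.]])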
    """

    def __init__(self,
                 normalizer: Union[Sequence[float], float] = 4.0,
                 clip_border: bool = True,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.normalizer = normalizer
        self.clip_border = clip_border

    def encode(self, bboxes: Union[Tensor, BaseBoxes],
               gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes`` in the (top, left,
        bottom, right) order.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): source boxes,
                e.g., object proposals.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): target of the
                transformation, e.g., ground truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas
        """
        bboxes = get_box_tensor(bboxes)
        gt_bboxes = get_box_tensor(gt_bboxes)
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bboxes2tblr(
            bboxes, gt_bboxes, normalizer=self.normalizer)
        return encoded_bboxes

    def decode(
        self,
        bboxes: Union[Tensor, BaseBoxes],
        pred_bboxes: Tensor,
        max_shape: Optional[Union[Sequence[int], Tensor,
                                  Sequence[Sequence[int]]]] = None
    ) -> Union[Tensor, BaseBoxes]:
        """Apply transformation `pred_bboxes` to `boxes`.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes.Shape
                (B, N, 4) or (N, 4)
            pred_bboxes (torch.Tensor): Encoded boxes with shape
               (B, N, 4) or (N, 4)
            max_shape (Sequence[int] or torch.Tensor or Sequence[
               Sequence[int]],optional): Maximum bounds for boxes, specifies
               (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
               the max_shape should be a Sequence[Sequence[int]]
               and the length of max_shape should also be B.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
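
        Example (a minimal sketch showing border clipping via
        ``max_shape``):
            >>> coder = TBLRBBoxCoder(normalizer=1.0)
            >>> priors = torch.Tensor([[0., 0., 10., 10.]])
            >>> deltas = torch.Tensor([[1., 1., 1., 1.]])
            >>> # the unclipped box would be (-5, -5, 15, 15)
            >>> coder.decode(priors, deltas, max_shape=(8, 8))
            tensor([[0., 0., 8., 8.]])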
        """
        bboxes = get_box_tensor(bboxes)
        decoded_bboxes = tblr2bboxes(
            bboxes,
            pred_bboxes,
            normalizer=self.normalizer,
            max_shape=max_shape,
            clip_border=self.clip_border)

        if self.use_box_type:
            decoded_bboxes = HorizontalBoxes(decoded_bboxes)
        return decoded_bboxes


def bboxes2tblr(priors: Tensor,
                gts: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True) -> Tensor:
    """Encode ground truth boxes to tblr coordinate.

    It first convert the gt coordinate to tblr format,
     (top, bottom, left, right), relative to prior box centers.
     The tblr coordinate may be normalized by the side length of prior bboxes
     if `normalize_by_wh` is specified as True, and it is then normalized by
     the `normalizer` factor.

    Args:
        priors (Tensor): Prior boxes in point form.
            Shape: (num_proposals, 4).
        gts (Tensor): Coords of ground truth for each prior in point form.
            Shape: (num_proposals, 4).
        normalizer (Sequence[float] | float): normalization parameter of
            encoded boxes. If it is a list, it has to have length = 4.
            Default: 4.0
        normalize_by_wh (bool): Whether to normalize tblr coordinate by the
            side length (wh) of prior bboxes.

    Returns:
        Tensor: Encoded boxes with shape (num_proposals, 4).
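
    Example (an illustrative sketch with a single prior):
        >>> priors = torch.Tensor([[0., 0., 10., 10.]])  # center (5, 5)
        >>> gts = torch.Tensor([[1., 2., 9., 8.]])
        >>> # raw (t, b, l, r) = (3, 3, 4, 4), then / wh, then / 4.0
        >>> bboxes2tblr(priors, gts)
        tensor([[0.0750, 0.0750, 0.1000, 0.1000]])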
    """

    if not isinstance(normalizer, float):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == gts.size(0)
    # Distances between the gt box sides and the matched prior's center
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    xmin, ymin, xmax, ymax = gts.split(1, dim=1)
    top = prior_centers[:, 1].unsqueeze(1) - ymin
    bottom = ymax - prior_centers[:, 1].unsqueeze(1)
    left = prior_centers[:, 0].unsqueeze(1) - xmin
    right = xmax - prior_centers[:, 0].unsqueeze(1)
    loc = torch.cat((top, bottom, left, right), dim=1)
    if normalize_by_wh:
        # Normalize tblr by anchor width and height
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc[:, :2] /= h  # tb is normalized by h
        loc[:, 2:] /= w  # lr is normalized by w
    # Normalize tblr by the given normalization factor
    return loc / normalizer


def tblr2bboxes(priors: Tensor,
                tblr: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True,
                max_shape: Optional[Union[Sequence[int], Tensor,
                                          Sequence[Sequence[int]]]] = None,
                clip_border: bool = True) -> Tensor:
    """Decode tblr outputs to prediction boxes.

    The process includes 3 steps: 1) De-normalize tblr coordinates by
    multiplying them with `normalizer`; 2) De-normalize tblr coordinates by
    the prior bbox width and height if `normalize_by_wh` is `True`; 3)
    Convert the tblr (top, bottom, left, right) pair relative to the center
    of priors back to (xmin, ymin, xmax, ymax) coordinates.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1).
          Shape: (N, 4) or (B, N, 4).
        tblr (Tensor): Coords of network output in tblr form.
          Shape: (N, 4) or (B, N, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
          encoded boxes. As a list, it gives the normalization factor of
          each tblr dim; as a float, it is a single factor shared by all
          dims. Default: 4.0
        normalize_by_wh (bool): Whether the tblr coordinates have been
          normalized by the side lengths (wh) of the prior bboxes.
        max_shape (Sequence[int] or torch.Tensor or
            Sequence[Sequence[int]], optional): Maximum bounds for boxes,
            specifies (H, W, C) or (H, W). If priors shape is (B, N, 4),
            then max_shape should be a Sequence[Sequence[int]] whose
            length is also B.
        clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. Defaults to True.

    Returns:
        Tensor: Decoded boxes with shape (N, 4) or (B, N, 4).
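
    Example (a round-trip sketch inverting the :func:`bboxes2tblr` output
    above):
        >>> priors = torch.Tensor([[0., 0., 10., 10.]])
        >>> tblr = torch.Tensor([[0.0750, 0.0750, 0.1000, 0.1000]])
        >>> tblr2bboxes(priors, tblr)
        tensor([[1., 2., 9., 8.]])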
    """
    if not isinstance(normalizer, float):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    loc_decode = tblr * normalizer
    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
    if normalize_by_wh:
        wh = priors[..., 2:4] - priors[..., 0:2]
        w, h = torch.split(wh, 1, dim=-1)
        # In-place slice operations would fail when exporting to ONNX
        th = h * loc_decode[..., :2]  # tb
        tw = w * loc_decode[..., 2:]  # lr
        loc_decode = torch.cat([th, tw], dim=-1)
    # `loc_decode.split(1, dim=-1)` cannot be exported to ONNX; split with
    # explicit sizes instead
    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
    xmin = prior_centers[..., 0].unsqueeze(-1) - left
    xmax = prior_centers[..., 0].unsqueeze(-1) + right
    ymin = prior_centers[..., 1].unsqueeze(-1) - top
    ymax = prior_centers[..., 1].unsqueeze(-1) + bottom

    bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)

    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            from mmdet.core.export import dynamic_clip_for_onnx
            xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
                xmin, ymin, xmax, ymax, max_shape)
            bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = priors.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(priors)
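        # A 2-D max_shape carries one (H, W) bound per image in the batch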
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = priors.new_tensor(0)
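        # (H, W) -> (H, W, H, W), flipped to (W, H, W, H) to match the
        # (x1, y1, x2, y2) layout; unsqueeze(-2) broadcasts over the boxes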
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes