KyanChen's picture
Upload 89 files
3094730
raw
history blame
8.8 kB
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weight_reduce_loss
from mmdet.structures.bbox import HorizontalBoxes
from mmyolo.registry import MODELS
def bbox_overlaps(pred: torch.Tensor,
target: torch.Tensor,
iou_mode: str = 'ciou',
bbox_format: str = 'xywh',
siou_theta: float = 4.0,
eps: float = 1e-7) -> torch.Tensor:
r"""Calculate overlap between two set of bboxes.
`Implementation of paper `Enhancing Geometric Factors into
Model Learning and Inference for Object Detection and Instance
Segmentation <https://arxiv.org/abs/2005.03572>`_.
In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
difference in the way the alpha parameter is computed.
mmdet version:
alpha = (ious > 0.5).float() * v / (1 - ious + v)
YOLOv5 version:
alpha = v / (v - ious + (1 + eps)
Args:
pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
or (x, y, w, h),shape (n, 4).
target (Tensor): Corresponding gt bboxes, shape (n, 4).
iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou').
Defaults to "ciou".
bbox_format (str): Options are "xywh" and "xyxy".
Defaults to "xywh".
siou_theta (float): siou_theta for SIoU when calculate shape cost.
Defaults to 4.0.
eps (float): Eps to avoid log(0).
Returns:
Tensor: shape (n, ).
"""
assert iou_mode in ('iou', 'ciou', 'giou', 'siou')
assert bbox_format in ('xyxy', 'xywh')
if bbox_format == 'xywh':
pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
target = HorizontalBoxes.cxcywh_to_xyxy(target)
bbox1_x1, bbox1_y1 = pred[..., 0], pred[..., 1]
bbox1_x2, bbox1_y2 = pred[..., 2], pred[..., 3]
bbox2_x1, bbox2_y1 = target[..., 0], target[..., 1]
bbox2_x2, bbox2_y2 = target[..., 2], target[..., 3]
# Overlap
overlap = (torch.min(bbox1_x2, bbox2_x2) -
torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
(torch.min(bbox1_y2, bbox2_y2) -
torch.max(bbox1_y1, bbox2_y1)).clamp(0)
# Union
w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
union = (w1 * h1) + (w2 * h2) - overlap + eps
h1 = bbox1_y2 - bbox1_y1 + eps
h2 = bbox2_y2 - bbox2_y1 + eps
# IoU
ious = overlap / union
# enclose area
enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
enclose_w = enclose_wh[..., 0] # cw
enclose_h = enclose_wh[..., 1] # ch
if iou_mode == 'ciou':
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
# calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
# Width and height ratio (v)
wh_ratio = (4 / (math.pi**2)) * torch.pow(
torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = wh_ratio / (wh_ratio - ious + (1 + eps))
# CIoU
ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))
elif iou_mode == 'giou':
# GIoU = IoU - ( (A_c - union) / A_c )
convex_area = enclose_w * enclose_h + eps # convex area (A_c)
ious = ious - (convex_area - union) / convex_area
elif iou_mode == 'siou':
# SIoU: https://arxiv.org/pdf/2205.12740.pdf
# SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )
# calculate sigma (σ):
# euclidean distance between bbox2(pred) and bbox1(gt) center point,
# sigma_cw = b_cx_gt - b_cx
sigma_cw = (bbox2_x1 + bbox2_x2) / 2 - (bbox1_x1 + bbox1_x2) / 2 + eps
# sigma_ch = b_cy_gt - b_cy
sigma_ch = (bbox2_y1 + bbox2_y2) / 2 - (bbox1_y1 + bbox1_y2) / 2 + eps
# sigma = √( (sigma_cw ** 2) - (sigma_ch ** 2) )
sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)
# choose minimize alpha, sin(alpha)
sin_alpha = torch.abs(sigma_ch) / sigma
sin_beta = torch.abs(sigma_cw) / sigma
sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
sin_beta)
# Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
# Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
rho_x = (sigma_cw / enclose_w)**2 # ρ_x
rho_y = (sigma_ch / enclose_h)**2 # ρ_y
gamma = 2 - angle_cost # γ
distance_cost = (1 - torch.exp(-1 * gamma * rho_x)) + (
1 - torch.exp(-1 * gamma * rho_y))
# Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2) # ω_w
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2) # ω_h
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w),
siou_theta) + torch.pow(
1 - torch.exp(-1 * omiga_h), siou_theta)
ious = ious - ((distance_cost + shape_cost) * 0.5)
return ious.clamp(min=-1.0, max=1.0)
@MODELS.register_module()
class IoULoss(nn.Module):
"""IoULoss.
Computing the IoU loss between a set of predicted bboxes and target bboxes.
Args:
iou_mode (str): Options are "ciou".
Defaults to "ciou".
bbox_format (str): Options are "xywh" and "xyxy".
Defaults to "xywh".
eps (float): Eps to avoid log(0).
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of loss.
return_iou (bool): If True, return loss and iou.
"""
def __init__(self,
iou_mode: str = 'ciou',
bbox_format: str = 'xywh',
eps: float = 1e-7,
reduction: str = 'mean',
loss_weight: float = 1.0,
return_iou: bool = True):
super().__init__()
assert bbox_format in ('xywh', 'xyxy')
assert iou_mode in ('ciou', 'siou', 'giou')
self.iou_mode = iou_mode
self.bbox_format = bbox_format
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
self.return_iou = return_iou
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
avg_factor: Optional[float] = None,
reduction_override: Optional[Union[str, bool]] = None
) -> Tuple[Union[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Forward function.
Args:
pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
or (x, y, w, h),shape (n, 4).
target (Tensor): Corresponding gt bboxes, shape (n, 4).
weight (Tensor, optional): Element-wise weights.
avg_factor (float, optional): Average factor when computing the
mean of losses.
reduction_override (str, bool, optional): Same as built-in losses
of PyTorch. Defaults to None.
Returns:
loss or tuple(loss, iou):
"""
if weight is not None and not torch.any(weight > 0):
if pred.dim() == weight.dim() + 1:
weight = weight.unsqueeze(1)
return (pred * weight).sum() # 0
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if weight is not None and weight.dim() > 1:
weight = weight.mean(-1)
iou = bbox_overlaps(
pred,
target,
iou_mode=self.iou_mode,
bbox_format=self.bbox_format,
eps=self.eps)
loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
reduction, avg_factor)
if self.return_iou:
return loss, iou
else:
return loss