Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from mmdet.models.losses.utils import weight_reduce_loss | |
from mmdet.structures.bbox import HorizontalBoxes | |
from mmyolo.registry import MODELS | |
def bbox_overlaps(pred: torch.Tensor,
                  target: torch.Tensor,
                  iou_mode: str = 'ciou',
                  bbox_format: str = 'xywh',
                  siou_theta: float = 4.0,
                  eps: float = 1e-7) -> torch.Tensor:
    r"""Calculate the element-wise overlap between two sets of bboxes.

    `Implementation of paper `Enhancing Geometric Factors into
    Model Learning and Inference for Object Detection and Instance
    Segmentation <https://arxiv.org/abs/2005.03572>`_.

    In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
    difference in the way the alpha parameter is computed.

    mmdet version:
        alpha = (ious > 0.5).float() * v / (1 - ious + v)
    YOLOv5 version (the one used here):
        alpha = v / (v - ious + (1 + eps))

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
            or (cx, cy, w, h), shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou').
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        siou_theta (float): Exponent applied to the SIoU shape cost.
            Defaults to 4.0.
        eps (float): Small constant that keeps the denominators away from
            zero. Defaults to 1e-7.

    Returns:
        Tensor: Overlap values clamped to [-1, 1], shape (n, ) for
        (n, 4) inputs.

    Raises:
        AssertionError: If ``iou_mode`` or ``bbox_format`` is not one of
            the supported options.
    """
    assert iou_mode in ('iou', 'ciou', 'giou', 'siou')
    assert bbox_format in ('xyxy', 'xywh')
    if bbox_format == 'xywh':
        pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
        target = HorizontalBoxes.cxcywh_to_xyxy(target)

    # bbox1 is the prediction, bbox2 is the target.
    bbox1_x1, bbox1_y1 = pred[..., 0], pred[..., 1]
    bbox1_x2, bbox1_y2 = pred[..., 2], pred[..., 3]
    bbox2_x1, bbox2_y1 = target[..., 0], target[..., 1]
    bbox2_x2, bbox2_y2 = target[..., 2], target[..., 3]

    # Overlap
    overlap = (torch.min(bbox1_x2, bbox2_x2) -
               torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
              (torch.min(bbox1_y2, bbox2_y2) -
               torch.max(bbox1_y1, bbox2_y1)).clamp(0)

    # Union (computed from the raw heights, before eps is added to them)
    w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
    w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
    union = (w1 * h1) + (w2 * h2) - overlap + eps

    # From here on the heights are only used as denominators.
    h1 = bbox1_y2 - bbox1_y1 + eps
    h2 = bbox2_y2 - bbox2_y1 + eps

    # IoU
    ious = overlap / union

    # Smallest box enclosing both pred and target
    enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
    enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    enclose_w = enclose_wh[..., 0]  # cw
    enclose_h = enclose_wh[..., 1]  # ch

    if iou_mode == 'ciou':
        # CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )

        # squared diagonal of the enclosing box (c^2)
        enclose_area = enclose_w**2 + enclose_h**2 + eps

        # ρ^2(b_pred,b_gt): squared euclidean distance between the centers
        # of b_pred (bbox1) and b_gt (bbox2). Each center coordinate is
        # (x1 + x2) / 2, so the squared difference is divided by 4.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)

        # Width and height ratio (v)
        wh_ratio = (4 / (math.pi**2)) * torch.pow(
            torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

        # alpha is treated as a constant: no gradient flows through it.
        with torch.no_grad():
            alpha = wh_ratio / (wh_ratio - ious + (1 + eps))

        # CIoU
        ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))

    elif iou_mode == 'giou':
        # GIoU = IoU - ( (A_c - union) / A_c )
        convex_area = enclose_w * enclose_h + eps  # convex area (A_c)
        ious = ious - (convex_area - union) / convex_area

    elif iou_mode == 'siou':
        # SIoU: https://arxiv.org/pdf/2205.12740.pdf
        # SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )

        # sigma (σ): euclidean distance between the centers of
        # b_gt (bbox2) and b_pred (bbox1).
        # sigma_cw = b_cx_gt - b_cx
        sigma_cw = (bbox2_x1 + bbox2_x2) / 2 - (bbox1_x1 + bbox1_x2) / 2 + eps
        # sigma_ch = b_cy_gt - b_cy
        sigma_ch = (bbox2_y1 + bbox2_y2) / 2 - (bbox1_y1 + bbox1_y2) / 2 + eps
        # sigma = √( (sigma_cw ** 2) + (sigma_ch ** 2) )
        sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)

        # choose the smaller of the two angles: use sin(alpha) while it is
        # at most sin(pi / 4), otherwise fall back to sin(beta)
        sin_alpha = torch.abs(sigma_ch) / sigma
        sin_beta = torch.abs(sigma_cw) / sigma
        sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
                                sin_beta)

        # Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
        #            = cos( 2 * arcsin(x) - (pi / 2) )
        angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)

        # Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
        # The enclosing box collapses to zero width/height for degenerate
        # inputs; clamping the denominators avoids inf/NaN there and leaves
        # every non-degenerate result unchanged.
        rho_x = (sigma_cw / enclose_w.clamp(min=eps))**2  # ρ_x
        rho_y = (sigma_ch / enclose_h.clamp(min=eps))**2  # ρ_y
        gamma = 2 - angle_cost  # γ
        distance_cost = (1 - torch.exp(-1 * gamma * rho_x)) + (
            1 - torch.exp(-1 * gamma * rho_y))

        # Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
        # Widths carry no eps, so two zero-width boxes used to yield
        # 0 / 0 = NaN; the clamp keeps the ratio finite.
        omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2).clamp(min=eps)  # ω_w
        omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2).clamp(min=eps)  # ω_h
        shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w),
                               siou_theta) + torch.pow(
                                   1 - torch.exp(-1 * omiga_h), siou_theta)

        ious = ious - ((distance_cost + shape_cost) * 0.5)

    return ious.clamp(min=-1.0, max=1.0)
class IoULoss(nn.Module):
    """IoULoss.

    Computing the IoU loss (``1 - IoU``) between a set of predicted bboxes
    and target bboxes.

    Args:
        iou_mode (str): Options are "ciou", "siou" and "giou".
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        eps (float): Small constant forwarded to ``bbox_overlaps``.
            Defaults to 1e-7.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean".
        loss_weight (float): Weight of loss. Defaults to 1.0.
        return_iou (bool): If True, return loss and iou.
            Defaults to True.
    """

    def __init__(self,
                 iou_mode: str = 'ciou',
                 bbox_format: str = 'xywh',
                 eps: float = 1e-7,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_iou: bool = True):
        super().__init__()
        assert bbox_format in ('xywh', 'xyxy')
        assert iou_mode in ('ciou', 'siou', 'giou')
        self.iou_mode = iou_mode
        self.bbox_format = bbox_format
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_iou = return_iou

    def _compute_iou(self, pred: torch.Tensor,
                     target: torch.Tensor) -> torch.Tensor:
        """Run ``bbox_overlaps`` with this module's configuration."""
        return bbox_overlaps(
            pred,
            target,
            iou_mode=self.iou_mode,
            bbox_format=self.bbox_format,
            eps=self.eps)

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        avg_factor: Optional[float] = None,
        reduction_override: Optional[Union[str, bool]] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
                or (x, y, w, h), shape (n, 4).
            target (Tensor): Corresponding gt bboxes, shape (n, 4).
            weight (Tensor, optional): Element-wise weights.
            avg_factor (float, optional): Average factor when computing the
                mean of losses.
            reduction_override (str, bool, optional): Same as built-in losses
                of PyTorch. Defaults to None.

        Returns:
            Tensor or tuple(Tensor, Tensor): ``loss`` when ``return_iou`` is
            False, otherwise ``(loss, iou)``.
        """
        if weight is not None and not torch.any(weight > 0):
            # No positive weight: skip the reduction and return the weighted
            # sum of pred, which is 0 when the weights are non-negative.
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            zero_loss = (pred * weight).sum()  # 0
            if not self.return_iou:
                return zero_loss
            # Keep the (loss, iou) contract on this path too, so callers
            # that unpack two values do not fail.
            return zero_loss, self._compute_iou(pred, target)

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # Collapse per-coordinate weights to one weight per box so they
        # line up with the per-box iou values.
        if weight is not None and weight.dim() > 1:
            weight = weight.mean(-1)

        iou = self._compute_iou(pred, target)
        loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
                                                     reduction, avg_factor)

        if self.return_iou:
            return loss, iou
        return loss