Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from utils import * | |
""" | |
Class for loss for training YOLO model. | |
Argmunets: | |
h_coord: weight for loss related to coordinates and shapes of box | |
h__noobj: weight for loss of predicting presence of box when it is absent. | |
""" | |
class YOLOLoss(nn.Module): | |
def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10.): | |
super().__init__() | |
self.h_coord = h_coord | |
self.h_noobj = h_noobj | |
self.h_shape = h_shape | |
self.h_obj = h_obj | |
def square_error(self, output, target): | |
return (output - target) ** 2 | |
def forward(self, output, target): | |
pred_xy, pred_wh, pred_obj = yolo_head(output) | |
gt_xy, gt_wh, gt_obj = process_target(target) | |
pred_ul = pred_xy - 0.5 * pred_wh | |
pred_br = pred_xy + 0.5 * pred_wh | |
pred_area = pred_wh[..., 0] * pred_wh[..., 1] | |
gt_ul = gt_xy - 0.5 * gt_wh | |
gt_br = gt_xy + 0.5 * gt_wh | |
gt_area = gt_wh[..., 0] * gt_wh[..., 1] | |
intersect_ul = torch.max(pred_ul, gt_ul) | |
intersect_br = torch.min(pred_br, gt_br) | |
intersect_wh = intersect_br - intersect_ul | |
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] | |
iou = intersect_area / (pred_area + gt_area - intersect_area) | |
max_iou = torch.max(iou, dim=3, keepdim=True)[0] | |
best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1) | |
gt_box_conf = best_box_index * gt_obj | |
xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum() | |
wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum() | |
obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum() | |
noobj_loss = (self.square_error(pred_obj, gt_obj) * (1 - gt_box_conf)).sum() | |
total_loss = self.h_coord * xy_loss + self.h_shape * wh_loss + self.h_obj * obj_loss + self.h_noobj * noobj_loss | |
return total_loss |