# sadjava's picture
# changed to pipelines
# fd52b7f
import torch
import torch.nn as nn
from utils import *
"""
Class for loss for training YOLO model.
Argmunets:
h_coord: weight for loss related to coordinates and shapes of box
h__noobj: weight for loss of predicting presence of box when it is absent.
"""
class YOLOLoss(nn.Module):
    """Weighted sum-squared-error loss for training a YOLO detector.

    The IoU between each predicted box and its ground-truth box is computed.
    Along the predictor axis, the predictor(s) with the highest IoU are
    "responsible" wherever the ground-truth objectness ``gt_obj`` is non-zero.
    Responsible predictors are penalized on center, size and objectness. Every
    other predictor is pushed towards an objectness of 0.

    Args:
        h_coord: weight of the box-center (xy) squared error.
        h_noobj: weight of the objectness error of non-responsible predictors,
            i.e. the cost of predicting a box where none is assigned.
        h_shape: weight of the box width/height (wh) squared error.
        h_obj: weight of the objectness error of responsible predictors.
    """

    def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10.):
        super().__init__()
        self.h_coord = h_coord
        self.h_noobj = h_noobj
        self.h_shape = h_shape
        self.h_obj = h_obj

    def square_error(self, output, target):
        """Return the element-wise squared difference ``(output - target) ** 2``."""
        return (output - target) ** 2

    @staticmethod
    def _box_iou(xy_a, wh_a, xy_b, wh_b, eps=1e-9):
        """Element-wise IoU of boxes given by center ``xy`` and size ``wh``.

        The last axis of every argument holds the two coordinates; the result
        has that axis removed. Arguments must be broadcast-compatible.
        """
        ul_a, br_a = xy_a - 0.5 * wh_a, xy_a + 0.5 * wh_a
        ul_b, br_b = xy_b - 0.5 * wh_b, xy_b + 0.5 * wh_b
        area_a = wh_a[..., 0] * wh_a[..., 1]
        area_b = wh_b[..., 0] * wh_b[..., 1]
        # Disjoint boxes yield a negative extent on at least one axis. Clamp it
        # so their intersection is 0 rather than a spurious product of two
        # negative extents (which would be positive).
        inter_wh = (torch.min(br_a, br_b) - torch.max(ul_a, ul_b)).clamp(min=0)
        inter_area = inter_wh[..., 0] * inter_wh[..., 1]
        # Floor the union at eps so two zero-area boxes give IoU 0, not NaN.
        union = (area_a + area_b - inter_area).clamp(min=eps)
        return inter_area / union

    def forward(self, output, target):
        """Compute the total weighted loss.

        Args:
            output: raw network output; decoded by ``yolo_head`` (from
                ``utils``) into predicted centers, sizes and objectness.
            target: ground truth; decoded by ``process_target`` (from
                ``utils``) into matching centers, sizes and objectness.

        Returns:
            Scalar tensor ``h_coord * xy + h_shape * wh + h_obj * obj
            + h_noobj * noobj``, each term summed over all elements.
        """
        pred_xy, pred_wh, pred_obj = yolo_head(output)
        gt_xy, gt_wh, gt_obj = process_target(target)

        # NOTE(review): assumes xy/wh tensors are (..., B, 2) with the B box
        # predictors on the second-to-last axis, so ``iou`` is (..., B) --
        # confirm against utils.yolo_head / process_target.
        iou = self._box_iou(pred_xy, pred_wh, gt_xy, gt_wh)
        max_iou = torch.max(iou, dim=-1, keepdim=True)[0]
        # 1.0 for every predictor attaining the best IoU (ties all count), with
        # a trailing axis so it broadcasts against the (..., 2) box errors.
        best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1)
        # Responsible predictors: best IoU where the ground truth has an object.
        gt_box_conf = best_box_index * gt_obj

        xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum()
        wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum()
        obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum()
        # Non-responsible predictors should predict "no object", so their
        # objectness target is 0. Regressing them towards gt_obj (presumably 1
        # in object cells) would reward a second predictor for firing on an
        # object that another predictor already covers.
        noobj_target = torch.zeros_like(pred_obj)
        noobj_loss = (self.square_error(pred_obj, noobj_target) * (1 - gt_box_conf)).sum()

        return (self.h_coord * xy_loss + self.h_shape * wh_loss
                + self.h_obj * obj_loss + self.h_noobj * noobj_loss)