File size: 1,956 Bytes
fd52b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
from utils import *

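"""
    Loss function for training a YOLO model.

    Arguments:
    h_coord: weight for the loss on box center coordinates (x, y).
    h_noobj: weight for the loss of predicting the presence of a box when it is absent.
    h_shape: weight for the loss on box shape (width, height).
    h_obj: weight for the confidence loss of boxes responsible for an object.
"""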
class YOLOLoss(nn.Module):
    """Weighted sum-of-squared-errors loss for training a YOLO detector.

    Args:
        h_coord: weight of the box-center (x, y) loss.
        h_noobj: weight of the confidence loss on boxes that are NOT
            responsible for an object; their confidence is pushed toward 0.
        h_shape: weight of the box width/height loss.
        h_obj: weight of the confidence loss on responsible boxes.
        eps: small constant added to the IoU denominator so that two
            zero-area boxes give IoU 0 instead of NaN (0 / 0).
    """

    def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10., eps=1e-9):
        super().__init__()
        self.h_coord = h_coord
        self.h_noobj = h_noobj
        self.h_shape = h_shape
        self.h_obj = h_obj
        self.eps = eps

    def square_error(self, output, target):
        """Return the element-wise squared difference ``(output - target) ** 2``."""
        return (output - target) ** 2

    def forward(self, output, target):
        """Compute the total weighted YOLO loss.

        ``output`` is decoded by ``yolo_head`` and ``target`` by
        ``process_target`` (both from ``utils``) into box-center, box-size and
        objectness tensors.

        NOTE(review): the code assumes ``*_xy`` / ``*_wh`` hold (x, y) / (w, h)
        in their last dimension and that dim 3 of the IoU tensor indexes the
        candidate boxes of a cell -- confirm against ``utils``.

        Returns:
            Scalar tensor: weighted sum of the center, shape, object and
            no-object losses, summed (not averaged) over all elements.
        """
        pred_xy, pred_wh, pred_obj = yolo_head(output)
        gt_xy, gt_wh, gt_obj = process_target(target)

        # Upper-left / bottom-right corners and areas of predicted boxes.
        pred_ul = pred_xy - 0.5 * pred_wh
        pred_br = pred_xy + 0.5 * pred_wh
        pred_area = pred_wh[..., 0] * pred_wh[..., 1]

        # Same for ground-truth boxes.
        gt_ul = gt_xy - 0.5 * gt_wh
        gt_br = gt_xy + 0.5 * gt_wh
        gt_area = gt_wh[..., 0] * gt_wh[..., 1]

        # Clamp at 0: for boxes that do not overlap, br - ul is negative and,
        # left unclamped, would produce a spurious (even positive) intersection.
        intersect_ul = torch.max(pred_ul, gt_ul)
        intersect_br = torch.min(pred_br, gt_br)
        intersect_wh = (intersect_br - intersect_ul).clamp(min=0)
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]

        # eps avoids 0 / 0 = NaN when both boxes have zero area.
        union_area = pred_area + gt_area - intersect_area
        iou = intersect_area / (union_area + self.eps)

        # Mark the candidate box(es) with the highest IoU as responsible.
        # torch.eq keeps every tied maximum, so ties mark several boxes.
        max_iou = torch.max(iou, dim=3, keepdim=True)[0]
        best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1)
        gt_box_conf = best_box_index * gt_obj
        noobj_mask = 1 - gt_box_conf

        xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum()
        wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum()
        obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum()
        # Boxes not responsible for an object should predict "no object", so
        # their confidence target is 0 (not gt_obj, which would push
        # non-responsible boxes in object cells toward 1).
        noobj_loss = (self.square_error(pred_obj, torch.zeros_like(pred_obj)) * noobj_mask).sum()

        total_loss = (
            self.h_coord * xy_loss
            + self.h_shape * wh_loss
            + self.h_obj * obj_loss
            + self.h_noobj * noobj_loss
        )
        return total_loss