|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
|
|
|
|
from ..utils import box_utils
|
|
|
|
|
|
class MultiboxLoss(nn.Module):
    """SSD Multibox loss: classification loss plus Smooth L1 box regression."""

    def __init__(self, priors, iou_threshold, neg_pos_ratio,
                 center_variance, size_variance, device):
        """Implement SSD Multibox Loss.

        Basically, Multibox loss combines classification loss
        and Smooth L1 regression loss.

        Args:
            priors: object exposing ``.to(device)`` (presumably the tensor of
                prior/anchor boxes -- confirm against the caller). It is
                stored after being moved to ``device``.
            iou_threshold: stored on the instance; not read by ``forward``.
            neg_pos_ratio: forwarded to ``box_utils.hard_negative_mining``.
            center_variance: stored on the instance; not read by ``forward``.
            size_variance: stored on the instance; not read by ``forward``.
            device: target passed to ``priors.to``.
        """
        super().__init__()
        self.iou_threshold = iou_threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.center_variance = center_variance
        self.size_variance = size_variance
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result has to be kept; otherwise the priors stay on their old device.
        self.priors = priors.to(device)

    def forward(self, confidence, predicted_locations, labels, gt_locations):
        """Compute classification loss and smooth l1 loss.

        Args:
            confidence (batch_size, num_priors, num_classes): class predictions.
            predicted_locations (batch_size, num_priors, 4): predicted locations.
            labels (batch_size, num_priors): real labels of all the priors.
            gt_locations (batch_size, num_priors, 4): real boxes corresponding
                to all the priors.

        Returns:
            Tuple ``(smooth_l1_loss, classification_loss)``. Both are summed
            over the selected priors and divided by the number of positive
            priors (``labels > 0``); a batch without positives divides by 1
            so the result stays finite.
        """
        num_classes = confidence.size(2)
        with torch.no_grad():
            # Negative log-probability of class 0 for every prior, used only
            # to rank priors for mining, hence no gradient tracking.
            loss = -F.log_softmax(confidence, dim=2)[:, :, 0]
            # NOTE(review): presumably keeps the positives plus the hardest
            # negatives at a neg_pos_ratio:1 ratio -- confirm in box_utils.
            mask = box_utils.hard_negative_mining(loss, labels, self.neg_pos_ratio)

        confidence = confidence[mask, :]
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False.
        classification_loss = F.cross_entropy(
            confidence.reshape(-1, num_classes), labels[mask], reduction='sum')

        # Box regression is only computed on priors with a label above 0.
        pos_mask = labels > 0
        predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4)
        gt_locations = gt_locations[pos_mask, :].reshape(-1, 4)
        smooth_l1_loss = F.smooth_l1_loss(
            predicted_locations, gt_locations, reduction='sum')

        # Guard against a batch with no positive priors: dividing by zero
        # would turn both losses into NaN/inf and poison the gradients.
        num_pos = max(gt_locations.size(0), 1)
        return smooth_l1_loss / num_pos, classification_loss / num_pos
|
|
|