Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss for heatmap regression.

    Paper ref: 'Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression', Wang et al., ICCV 2019.

    For each element, with absolute error ``delta = |target - pred|`` and a
    per-element exponent ``p = alpha - target``, the loss is::

        omega * ln(1 + (delta / epsilon) ** p)   if delta < theta
        A * delta - C                            otherwise

    ``A`` and ``C`` are chosen so the two pieces agree in value and slope at
    ``delta == theta``. Per-element losses are averaged over all elements.

    Args:
        alpha (float): Exponent offset; the exponent is ``alpha - target``.
            Default: 2.1.
        omega (float): Scale of the loss. Default: 14.
        epsilon (float): Normalizer applied to the error inside the log.
            Default: 1.
        theta (float): Error threshold at which the loss switches from the
            log form to the linear form. Default: 0.5.
        use_target_weight (bool): If True, ``output`` and ``target`` are
            multiplied by ``target_weight`` before the loss is computed, so
            different joint types can have different weights.
            Default: False.
        loss_weight (float): Multiplier applied to the final loss.
            Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Compute the mean adaptive wing loss between two tensors.

        The computation is elementwise, so any pair of broadcastable tensors
        works; typically both are heatmaps of shape [N, K, H, W]
        (batch size N, K keypoints).

        Args:
            pred (torch.Tensor): Predicted heatmaps.
            target (torch.Tensor): Target heatmaps.

        Returns:
            torch.Tensor: Scalar mean of the per-element losses.
        """
        delta = (target - pred).abs()
        # The exponent depends on the target value of each element.
        exponent = self.alpha - target
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        # Slope (A) and offset (C) of the linear piece: A is the derivative
        # of the log piece at delta == theta, and C makes the pieces meet
        # there, so the loss is continuous and smooth at the switch point.
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * (torch.pow(
            ratio, exponent - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)

        log_part = self.omega * torch.log(
            1 + torch.pow(delta / self.epsilon, exponent))
        losses = torch.where(delta < self.theta, log_part, A * delta - C)
        return torch.mean(losses)

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor[NxKxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1] | None): Weights across
                different joint types. Trailing singleton dimensions are
                appended until it has as many dims as ``output``, so NxKx1
                and NxK weights both broadcast per keypoint, and a tensor
                already shaped like ``output`` is used as-is. Only used (and
                then required) when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss multiplied by ``loss_weight``.

        Raises:
            ValueError: If ``use_target_weight`` is True and
                ``target_weight`` is None.
        """
        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight is required when use_target_weight=True')
            # Pad with trailing singleton dims, e.g. [N, K, 1] or [N, K]
            # -> [N, K, 1, 1], so the weights broadcast against
            # [N, K, H, W] along the keypoint axis.
            weight = target_weight
            while weight.dim() < output.dim():
                weight = weight.unsqueeze(-1)
            loss = self.criterion(output * weight, target * weight)
        else:
            loss = self.criterion(output, target)
        return loss * self.loss_weight