import torch
import torch.nn as nn
from torch.distributions.kl import kl_divergence
from torch.distributions.normal import Normal
from torch.nn.functional import relu


class BatchHardTripletLoss(nn.Module):
    """Triplet loss with batch-hard mining.

    For every anchor in the batch, the farthest same-label sample (hardest
    positive) and the closest different-label sample (hardest negative) are
    selected, and ``relu(d_ap - d_an + margin)`` is aggregated over anchors.
    """

    _AGGREGATIONS = ('sum', 'mean')

    def __init__(self, margin=1., squared=False, agg='sum'):
        """Initialize the loss.

        Args:
            margin: Margin added to ``d_ap - d_an`` before the hinge.
            squared: If true, use squared Euclidean distances; otherwise
                take the square root.
            agg: How per-anchor losses are combined in ``forward``:
                ``'sum'`` or ``'mean'``.

        Raises:
            ValueError: If ``agg`` is not one of the supported aggregations.
        """
        super().__init__()
        if agg not in self._AGGREGATIONS:
            raise ValueError(
                f"agg must be one of {self._AGGREGATIONS}, got {agg!r}"
            )
        self.margin = margin
        self.squared = squared
        self.agg = agg
        # Added under the square root so that sqrt is never evaluated at
        # exactly zero (e.g. on the diagonal).
        self.eps = 1e-8

    def get_pairwise_distances(self, embeddings):
        """Compute Euclidean distances between all pairs of embeddings.

        Args:
            embeddings: 2-D tensor with one embedding per row.

        Returns:
            Square matrix whose ``[i, j]`` entry is the (squared, if
            ``self.squared``) distance between rows ``i`` and ``j``.
        """
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, with the squared norms
        # read off the diagonal of the Gram matrix.
        ab = embeddings.mm(embeddings.t())
        squared_norms = ab.diag()
        distances = squared_norms.unsqueeze(1) - 2 * ab + squared_norms.unsqueeze(0)
        # Clamp small negative values produced by floating-point error.
        distances = relu(distances)
        if not self.squared:
            distances = torch.sqrt(distances + self.eps)
        return distances

    def hardest_triplet_mining(self, dist_mat, labels):
        """Select the hardest positive and hardest negative for each anchor.

        Args:
            dist_mat: Square pairwise distance matrix.
            labels: Tensor holding one label per row of ``dist_mat``.

        Returns:
            Tuple ``(dist_ap, dist_an)``, each with one row per anchor and a
            single column: the largest distance to a same-label sample and
            the smallest distance to a different-label sample. An anchor
            with no different-label sample in the batch gets
            ``dist_an = +inf``, so its hinge loss is zero.
        """
        assert len(dist_mat.size()) == 2
        assert dist_mat.size(0) == dist_mat.size(1)

        flat_labels = labels.reshape(-1)
        is_pos = flat_labels.unsqueeze(0).eq(flat_labels.unsqueeze(1))
        is_neg = ~is_pos

        # Non-positive entries are zeroed; distances are non-negative, so
        # they can never win the max.
        dist_ap, _ = torch.max(
            dist_mat.masked_fill(is_neg, 0.), 1, keepdim=True)
        # Non-negative entries (including the anchor itself) must be pushed
        # to +inf; zeroing them would make every row minimum equal to 0.
        dist_an, _ = torch.min(
            dist_mat.masked_fill(is_pos, float('inf')), 1, keepdim=True)
        return dist_ap, dist_an

    def forward(self, embeddings, labels):
        """Return the batch-hard triplet loss aggregated per ``self.agg``.

        Args:
            embeddings: 2-D tensor with one embedding per row.
            labels: Tensor holding one label per embedding.

        Returns:
            Scalar tensor: the sum or mean of the per-anchor hinge losses.
        """
        distances = self.get_pairwise_distances(embeddings)
        dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)
        per_anchor = relu(dist_ap - dist_an + self.margin)
        if self.agg == 'mean':
            return per_anchor.mean()
        return per_anchor.sum()


class VAELoss(nn.Module):
    """Sum of a KL term against a standard normal and a BCE reconstruction term."""

    def __init__(self):
        super().__init__()
        self.reconstruction_loss = nn.BCELoss(reduction='sum')

    def kl_divergence_loss(self, q_dist):
        """Return KL(q_dist || N(0, 1)) summed over the last dimension.

        Args:
            q_dist: Distribution exposing ``mean`` and ``stddev`` for which
                ``kl_divergence`` against ``Normal`` is defined.
        """
        prior = Normal(
            torch.zeros_like(q_dist.mean),
            torch.ones_like(q_dist.stddev),
        )
        return kl_divergence(q_dist, prior).sum(-1)

    def forward(self, output, target, encoding):
        """Return the total KL divergence plus the summed BCE of output vs target.

        Args:
            output: Reconstruction passed to ``nn.BCELoss`` as the input.
            target: Reconstruction target passed to ``nn.BCELoss``.
            encoding: Approximate posterior passed to ``kl_divergence_loss``.
        """
        kl_term = self.kl_divergence_loss(encoding).sum()
        return kl_term + self.reconstruction_loss(output, target)