# Loss modules: batch-hard triplet loss and a VAE (BCE + KL) objective.
import torch
import torch.nn as nn
from torch.distributions.kl import kl_divergence
from torch.distributions.normal import Normal
from torch.nn.functional import relu
class BatchHardTripletLoss(nn.Module):
    """Batch-hard triplet loss over a batch of embeddings.

    For every anchor in the batch, selects the farthest same-label sample
    (hardest positive) and the closest different-label sample (hardest
    negative), then applies a hinge: relu(d_ap - d_an + margin).
    """

    def __init__(self, margin=1., squared=False, agg='sum'):
        """
        Args:
            margin: hinge margin between positive and negative distances.
            squared: if True, use squared Euclidean distances (no sqrt).
            agg: batch aggregation, 'sum' (default) or 'mean'.
        """
        super().__init__()
        self.margin = margin
        self.squared = squared
        self.agg = agg
        # Added under the sqrt so its gradient is finite at zero distance.
        self.eps = 1e-8

    def get_pairwise_distances(self, embeddings):
        """Return the (N, N) matrix of pairwise (squared) Euclidean distances.

        Uses ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 computed from a single
        Gram matrix; relu clamps tiny negative values from rounding error.
        """
        gram = embeddings.mm(embeddings.t())
        sq_norms = gram.diag()
        distances = relu(sq_norms.unsqueeze(1) - 2 * gram + sq_norms.unsqueeze(0))
        if not self.squared:
            distances = torch.sqrt(distances + self.eps)
        return distances

    def hardest_triplet_mining(self, dist_mat, labels):
        """Return (dist_ap, dist_an), each of shape (N, 1).

        dist_ap: distance to the hardest (farthest) positive per anchor.
        dist_an: distance to the hardest (closest) negative per anchor.
        """
        assert len(dist_mat.size()) == 2
        assert dist_mat.size(0) == dist_mat.size(1)
        N = dist_mat.size(0)
        is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
        dist_ap, _ = torch.max(dist_mat * is_pos, 1, keepdim=True)
        # BUG FIX: the original took min over (dist_mat * is_neg), which
        # zeroes out positive-pair entries, so the minimum was always 0
        # (a positive pair) rather than the hardest negative. Instead,
        # shift positive entries above the global max so min() can only
        # select a true negative.
        offset = dist_mat.max()
        dist_an, _ = torch.min(dist_mat + offset * is_pos, 1, keepdim=True)
        return dist_ap, dist_an

    def forward(self, embeddings, labels):
        """Compute the aggregated batch-hard triplet loss.

        Args:
            embeddings: (N, D) float tensor.
            labels: (N,) integer class labels.
        """
        distances = self.get_pairwise_distances(embeddings)
        dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)
        losses = relu(dist_ap - dist_an + self.margin)
        # Honor the `agg` option (previously stored but ignored); the
        # default 'sum' preserves the original behavior.
        return losses.mean() if self.agg == 'mean' else losses.sum()
class VAELoss(nn.Module):
    """VAE objective: summed BCE reconstruction term plus KL to N(0, I)."""

    def __init__(self):
        super().__init__()
        # Summed (not averaged) so the two terms stay on comparable scales.
        self.reconstruction_loss = nn.BCELoss(reduction='sum')

    def kl_divergence_loss(self, q_dist):
        """KL(q || N(0, I)), summed over the last (latent) dimension."""
        prior = Normal(
            torch.zeros_like(q_dist.mean),
            torch.ones_like(q_dist.stddev),
        )
        return kl_divergence(q_dist, prior).sum(-1)

    def forward(self, output, target, encoding):
        """Return KL(encoding || prior) summed over the batch, plus BCE(output, target)."""
        kl_term = self.kl_divergence_loss(encoding).sum()
        recon_term = self.reconstruction_loss(output, target)
        return kl_term + recon_term