Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class PairwizeDiffLoss(nn.Module):
    """Penalise mismatches between pairwise differences of input and target.

    For every pair ``(i, j)`` along the first dimension the loss compares
    ``input[i] - input[j]`` with ``target[i] - target[j]``. The discrepancy
    is measured with the configured norm, reduced by ``margin`` and clamped
    at zero, so discrepancies no larger than the margin cost nothing.

    Args:
        margin: Amount subtracted from each pairwise discrepancy before the
            ReLU clamp.
        norm: ``"l1"`` for the absolute value or ``"l2_squared"`` for the
            element-wise square of the discrepancy.

    Raises:
        ValueError: If ``norm`` is not one of the supported names.
    """

    def __init__(self, margin: float = 0.2, norm: str = "l1"):
        super().__init__()
        # Fail fast: reject a bad configuration at construction time rather
        # than on the first forward pass.
        self._resolve_norm(norm)
        self.margin = margin
        self.norm = norm

    @staticmethod
    def _resolve_norm(norm: str):
        """Return the element-wise function for ``norm`` or raise ValueError."""
        if norm not in ["l1", "l2_squared"]:
            raise ValueError(
                f'Unknown norm: {norm}. Must be one of ["l1", "l2_squared"]'
            )
        return {
            "l1": torch.abs,
            "l2_squared": torch.square,
        }[norm]

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the pairwise-difference loss.

        Args:
            input: Predicted values; pairs are formed along dimension 0.
            target: Reference values. No shape check is performed here, so
                a shape that differs from ``input`` is combined by ordinary
                broadcasting. NOTE(review): callers presumably pass matching
                shapes -- confirm.

        Returns:
            A scalar tensor: the mean of the clamped discrepancies over all
            ordered pairs, divided by two.

        Raises:
            ValueError: If ``self.norm`` was changed to an unsupported name
                after construction.
        """
        # ``norm`` is a public attribute and may have been reassigned, so it
        # is re-validated before any O(N^2) tensor work is done.
        norm_fn = self._resolve_norm(self.norm)

        # Entry [i, j] holds element i minus element j.
        input_diff = input.unsqueeze(1) - input.unsqueeze(0)
        target_diff = target.unsqueeze(1) - target.unsqueeze(0)

        # With a non-negative margin the diagonal (i == j) contributes zero,
        # because both differences are exactly zero there.
        loss = F.relu(norm_fn(input_diff - target_diff) - self.margin)

        # Both norms are symmetric in (i, j), so each unordered pair appears
        # twice in the mean; the halving presumably compensates for that.
        return loss.mean().div(2)
class CombinedLoss(nn.Module):
    """Evaluate several loss modules on the same input/target pair.

    The individual results are returned separately together with their
    weights; this module does not multiply or sum them.

    Args:
        weighted_losses: ``(loss_module, weight)`` pairs, evaluated in the
            given order.
    """

    def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
        super().__init__()
        self.weighted_losses = weighted_losses
        # A plain Python list hides its modules from nn.Module, so their
        # parameters and buffers would be skipped by .to(), .parameters()
        # and state_dict(). Registering them in a ModuleList makes them
        # proper children. Non-Module callables are left unregistered.
        self.loss_modules = nn.ModuleList(
            loss for loss, _ in weighted_losses if isinstance(loss, nn.Module)
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> list[tuple[float, torch.Tensor]]:
        """Apply every configured loss to ``input`` and ``target``.

        Args:
            input: Passed as the first argument to each loss.
            target: Passed as the second argument to each loss.

        Returns:
            One ``(weight, loss_value)`` tuple per configured loss, in
            configuration order. Note that the weight comes first here,
            whereas the constructor takes ``(loss, weight)``.
        """
        # Iterates ``weighted_losses`` itself, so entries changed after
        # construction are still honoured.
        return [(w, loss(input, target)) for loss, w in self.weighted_losses]