UTMOSv2 / utmosv2 /loss /losses.py
kAIto47802
Resolved conflict in README.md
b55d767
raw
history blame
1.22 kB
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
class PairwizeDiffLoss(nn.Module):
    """Margin-based loss on pairwise differences between elements.

    For every ordered pair ``(i, j)`` of elements along dim 0, the predicted
    difference ``input[i] - input[j]`` is compared with the target difference
    ``target[i] - target[j]``. The deviation between the two, measured with
    ``norm``, is penalized only by the amount it exceeds ``margin``.

    Args:
        margin: Tolerance subtracted from each pairwise deviation before the
            ReLU; deviations not larger than ``margin`` contribute zero loss.
        norm: Deviation measure, ``"l1"`` (absolute value) or
            ``"l2_squared"`` (squared difference).

    Raises:
        ValueError: If ``norm`` is not one of the supported values (raised at
            construction, and again in ``forward`` if ``self.norm`` was later
            changed to an unsupported value).
    """

    # Element-wise deviation functions, keyed by the accepted ``norm`` values.
    _NORM_FNS = {
        "l1": torch.abs,
        "l2_squared": lambda x: x**2,
    }

    def __init__(self, margin: float = 0.2, norm: str = "l1"):
        super().__init__()
        self.margin = margin
        self.norm = norm
        # Fail fast on a misconfigured norm instead of at the first forward().
        self._get_norm_fn()

    def _get_norm_fn(self):
        """Return the deviation function selected by ``self.norm``, or raise."""
        if self.norm not in ("l1", "l2_squared"):
            raise ValueError(
                f'Unknown norm: {self.norm}. Must be one of ["l1", "l2_squared"]'
            )
        return self._NORM_FNS[self.norm]

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the pairwise-difference loss.

        Args:
            input: Predicted scores; pairs are formed along dim 0
                (presumably shape ``(batch,)`` — TODO confirm with callers).
            target: Reference scores with the same shape as ``input``.

        Returns:
            Scalar tensor: the mean over all ``B * B`` ordered pairs of
            ``relu(norm(s - t) - margin)``, divided by 2.

        Raises:
            ValueError: If ``self.norm`` is not a supported value.
        """
        # Validate before allocating the B x B matrices.
        norm_fn = self._get_norm_fn()
        # s[i, j] = input[i] - input[j]; t[i, j] = target[i] - target[j].
        s = input.unsqueeze(1) - input.unsqueeze(0)
        t = target.unsqueeze(1) - target.unsqueeze(0)
        loss = F.relu(norm_fn(s - t) - self.margin)
        # The mean runs over every ordered pair, including the diagonal
        # (which is zero for a non-negative margin), and is then halved.
        return loss.mean().div(2)
class CombinedLoss(nn.Module):
    """Evaluate several losses on the same ``(input, target)`` pair.

    The losses are not summed here: ``forward`` returns each weight next to
    its unweighted loss value, leaving the combination to the caller.

    Args:
        weighted_losses: ``(loss_module, weight)`` pairs, evaluated in order.
    """

    def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
        super().__init__()
        self.weighted_losses = weighted_losses
        # A plain list hides its modules from nn.Module. Registering them in a
        # ModuleList lets .to(), .train()/.eval(), .parameters() and
        # state_dict() reach the individual losses (e.g. a loss holding a
        # ``weight`` buffer now follows the parent's device). Non-Module
        # callables are skipped because ModuleList only accepts nn.Module.
        self._loss_modules = nn.ModuleList(
            loss for loss, _ in weighted_losses if isinstance(loss, nn.Module)
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> list[tuple[float, torch.Tensor]]:
        """Apply every configured loss to ``input`` and ``target``.

        Returns:
            A list of ``(weight, loss_value)`` tuples in the order given to
            the constructor; ``loss_value`` is not multiplied by ``weight``.
        """
        return [(w, loss(input, target)) for loss, w in self.weighted_losses]