Spaces:
No application file
No application file
from typing import List, Union | |
import torch | |
import torch.nn as nn | |
from .multi_layer_loss import MultiLayerLoss | |
from ..flow.util import torch_wrap | |
class FlowShortTermLoss(nn.Module): | |
def __init__(self, loss_fn: nn.Module = nn.MSELoss) -> None: | |
super().__init__() | |
self.loss_fn = loss_fn | |
def forward(self, output, wrap): | |
b, c, h, w = output.shape | |
loss = self.loss_fn(output, wrap) | |
return loss | |
class FlowLongTermLoss(MultiLayerLoss): | |
def __init__(self, loss_fn: nn.Module, weights: List[float] = None) -> None: | |
super().__init__(loss_fn, weights) | |
def forward( | |
self, outputs: List[torch.tensor], wraps: List[torch.tensor] | |
) -> torch.Tensor: | |
"""_summary_ | |
Args: | |
output (torch.Tensor): b * c * h * w | |
Returns: | |
torch.Tensor: _description_ | |
""" | |
assert len(outputs) == len( | |
wraps | |
), f"length should be x({len(outputs)}) == target({len(wraps)})" | |
if self.weights is not None: | |
assert len(outputs) == len( | |
self.weights | |
), f"weights should be None or length of x({len(outputs)}) must be equal to target({len(self.weights)})" | |
total_loss = 0 | |
for i in len(outputs): | |
b, c, h, w = output.shape | |
output = outputs[i] | |
wrap = wraps[i] | |
loss = self.loss_fn(output, wrap) # mseloss reduce=mean | |
if self.weights is not None: | |
loss *= self.weights[i] | |
total_loss += loss | |
return total_loss | |