File size: 1,584 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import List, Union

import torch
import torch.nn as nn
from .multi_layer_loss import MultiLayerLoss

from ..flow.util import torch_wrap


class FlowShortTermLoss(nn.Module):
    """Short-term (single-pair) flow-warping loss.

    Compares a network output against its flow-warped counterpart using
    ``loss_fn`` (mean-squared error by default).
    """

    def __init__(self, loss_fn: nn.Module = None) -> None:
        """
        Args:
            loss_fn: Loss module (or module class) applied to the pair.
                Defaults to ``nn.MSELoss()``.  NOTE(review): the previous
                default was the *class* ``nn.MSELoss``, so calling the stored
                attribute constructed a module instead of computing a loss.
        """
        super().__init__()
        # Accept None, an instance, or a class — always store a callable
        # loss *instance* so forward() works with the default.
        if loss_fn is None:
            loss_fn = nn.MSELoss()
        elif isinstance(loss_fn, type):
            loss_fn = loss_fn()
        self.loss_fn = loss_fn

    def forward(self, output: torch.Tensor, wrap: torch.Tensor) -> torch.Tensor:
        """Return ``loss_fn(output, wrap)``.

        Args:
            output: Predicted tensor — presumably b * c * h * w, but any
                shape ``loss_fn`` accepts works (the unused 4-D shape
                unpacking was removed).
            wrap: Flow-warped reference tensor with the same shape.

        Returns:
            The scalar (for reducing losses) loss tensor.
        """
        return self.loss_fn(output, wrap)


class FlowLongTermLoss(MultiLayerLoss):
    """Long-term flow-warping loss summed over multiple layers/frames.

    Applies ``loss_fn`` to each (output, warped-reference) pair and sums the
    results, optionally scaling each pair's loss by a per-layer weight.
    """

    def __init__(self, loss_fn: nn.Module, weights: List[float] = None) -> None:
        """
        Args:
            loss_fn: Loss module applied to each output/wrap pair.
            weights: Optional per-layer weights; must match the number of
                outputs when given.
        """
        super().__init__(loss_fn, weights)

    def forward(
        self, outputs: List[torch.Tensor], wraps: List[torch.Tensor]
    ) -> torch.Tensor:
        """Sum ``loss_fn(outputs[i], wraps[i])`` over all layers.

        Args:
            outputs: Predicted tensors, presumably each b * c * h * w —
                TODO confirm against callers.
            wraps: Flow-warped references, same length and shapes as
                ``outputs``.

        Returns:
            torch.Tensor: The (weighted) sum of per-layer losses.
        """
        assert len(outputs) == len(
            wraps
        ), f"length should be x({len(outputs)}) == target({len(wraps)})"
        if self.weights is not None:
            assert len(outputs) == len(
                self.weights
            ), f"weights should be None or length of x({len(outputs)}) must be equal to target({len(self.weights)})"

        total_loss = 0
        # Fixes: `for i in len(outputs)` raised TypeError (int is not
        # iterable), and `output.shape` was read before `output` was assigned.
        for i, (output, wrap) in enumerate(zip(outputs, wraps)):
            loss = self.loss_fn(output, wrap)  # e.g. MSELoss with reduction="mean"
            if self.weights is not None:
                # Out-of-place multiply: avoids in-place mutation of a tensor
                # that may be needed by autograd.
                loss = loss * self.weights[i]
            total_loss = total_loss + loss
        return total_loss