from typing import List, Optional, Union

import torch
import torch.nn as nn


class MultiLayerLoss(nn.Module):
    """Applies a base loss to multiple layer outputs and sums the (optionally weighted) results."""

    def __init__(self, loss_fn: nn.Module, weights: Optional[List[float]] = None) -> None:
        super().__init__()
        self.weights = weights
        self.loss_fn = loss_fn

    def forward(
        self,
        output: Union[torch.Tensor, List[torch.Tensor]],
        target: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Compute the summed (optionally weighted) per-layer loss.

        Args:
            output: tensor of shape (b, c, h, w) or a list of such tensors, one per layer.
            target: tensor or list of tensors matching ``output`` in length.

        Returns:
            torch.Tensor: the accumulated loss over all layers.
        """
        if not isinstance(output, list):
            output = [output]
        if not isinstance(target, list):
            target = [target]
        assert len(output) == len(target), \
            f"length of output ({len(output)}) must equal length of target ({len(target)})"
        if self.weights is not None:
            assert len(output) == len(self.weights), \
                f"weights must be None or match the number of outputs ({len(output)} vs {len(self.weights)})"

        total_loss = 0.0
        for i, (x, y) in enumerate(zip(output, target)):
            # Map each raw output/target to the representation the loss operates on.
            x = self._get_feature(x)
            y = self._get_feature(y)
            loss = self.loss_fn(x, y)
            if self.weights is not None:
                loss = loss * self.weights[i]
            total_loss = total_loss + loss
        return total_loss

    def cal_single_layer_loss(self, x, y):
        """Compute the loss for a single layer; hook intended to be overridden by subclasses."""
        raise NotImplementedError

    def _get_feature(self, x):
        """Extract the feature the loss is computed on; must be implemented by subclasses."""
        raise NotImplementedError
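

# Hypothetical usage sketch (an assumption, not part of the original module): a
# minimal subclass that uses the raw tensor as its per-layer "feature", shown
# only to illustrate how MultiLayerLoss is intended to be subclassed and called.
if __name__ == "__main__":
    class _IdentityFeatureLoss(MultiLayerLoss):
        def _get_feature(self, x):
            # Use the tensor itself as the feature for the layer-wise loss.
            return x

    criterion = _IdentityFeatureLoss(nn.MSELoss(), weights=[1.0, 0.5])
    outputs = [torch.randn(2, 3, 8, 8), torch.randn(2, 3, 4, 4)]
    targets = [torch.randn(2, 3, 8, 8), torch.randn(2, 3, 4, 4)]
    # Weighted sum of the two per-layer MSE terms.
    print(criterion(outputs, targets))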