File size: 1,301 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

from typing import List, Union, Callable

import torch
import torch.nn as nn

from .multi_layer_loss import MultiLayerLoss


class ContentLoss(nn.Module):
    """Sum of per-layer losses between the features of two inputs.

    Both inputs are passed through ``model``, which must return a mapping
    from a sortable key to a feature tensor. ``loss_fn`` is evaluated on each
    pair of same-key features and the (optionally weighted) results are
    summed.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Union[nn.Module, Callable] = nn.MSELoss,
        weights: Union[List[float], None] = None,
        transform: Union[Callable, None] = None,
    ) -> None:
        """Store the feature extractor and the loss configuration.

        Args:
            model: Callable returning a mapping ``key -> feature`` for an
                input tensor.
            loss_fn: Loss applied to each pair of features. Either a ready
                callable (e.g. ``nn.MSELoss()``) or a loss class, which is
                instantiated with no arguments.
            weights: Optional per-layer multipliers, indexed by the position
                of the layer key in sorted key order. ``None`` means every
                layer has weight 1.
            transform: Optional callable applied to both ``output`` and
                ``target`` before they are fed to ``model``.
        """
        super().__init__()
        # The default (and possibly a user argument) is a loss *class*;
        # calling a class with two tensors would construct a module instead
        # of computing a loss, so instantiate it here.
        if isinstance(loss_fn, type):
            loss_fn = loss_fn()
        self.model = model
        self.loss_fn = loss_fn
        self.weights = weights
        self.transform = transform

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the weighted sum of per-layer feature losses.

        Args:
            output (torch.Tensor): b * c * h * w
            target (torch.Tensor): reference input; it goes through the same
                ``transform`` and ``model`` as ``output``.

        Returns:
            torch.Tensor: sum over layers of ``loss_fn(output_feature,
            target_feature)``, each multiplied by its weight when ``weights``
            is set. If ``model`` returns no features the result is the plain
            integer ``0``.
        """
        if self.transform is not None:
            output = self.transform(output)
            target = self.transform(target)
        output_feature = self.model(output)
        target_feature = self.model(target)
        assert len(output_feature) == len(target_feature), (
            "model returned a different number of features for output and target"
        )
        # Sorted keys give a deterministic layer order for matching weights.
        keys = sorted(output_feature.keys())
        total_loss = 0
        for i, k in enumerate(keys):
            loss = self.loss_fn(output_feature[k], target_feature[k])
            # Out-of-place arithmetic: never mutate a tensor handed back by
            # loss_fn, which may be aliased or needed by autograd.
            if self.weights is not None:
                loss = loss * self.weights[i]
            total_loss = total_loss + loss
        return total_loss