Spaces:
No application file
No application file
File size: 1,301 Bytes
6755a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn

from .multi_layer_loss import MultiLayerLoss
class ContentLoss(nn.Module):
    """Weighted sum of per-layer feature losses between two inputs.

    ``model`` is applied to both ``output`` and ``target``. It must return a
    dict-like mapping of feature name -> feature tensor, because ``forward``
    calls ``.keys()`` on it and indexes it by key. ``loss_fn`` compares the
    two features for every key, in sorted key order. Each result is
    optionally scaled by the matching entry of ``weights``, and all terms
    are summed.

    Args:
        model: Feature extractor returning a dict-like of features.
        loss_fn: Loss callable used as ``loss_fn(pred, target)``. A loss
            *class* (such as the default ``nn.MSELoss``) is instantiated with
            no arguments. An instance or plain function is used as-is.
        weights: Optional per-layer weights, matched by position to the
            sorted feature keys.
        transform: Optional callable applied to both inputs before ``model``.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Union[nn.Module, Callable] = nn.MSELoss,
        weights: Optional[List[float]] = None,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.model = model
        # The default is the class ``nn.MSELoss``, not an instance. Calling
        # the class with two tensors would try to build a new module instead
        # of computing a loss, so instantiate classes here.
        if isinstance(loss_fn, type):
            loss_fn = loss_fn()
        self.loss_fn = loss_fn
        self.weights = weights
        self.transform = transform

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally weighted) sum of per-layer feature losses.

        Args:
            output: Predicted input, e.g. b * c * h * w (whatever ``model``
                and ``transform`` accept).
            target: Reference input in the same layout as ``output``.

        Returns:
            The summed loss. This is the integer ``0`` if ``model`` returns
            no features.

        Raises:
            ValueError: If the two feature mappings have different keys, or
                ``weights`` has fewer entries than there are feature layers.
        """
        if self.transform is not None:
            output = self.transform(output)
            target = self.transform(target)
        output_feature = self.model(output)
        target_feature = self.model(target)

        # Use an explicit check rather than ``assert``, which is stripped
        # under -O. Comparing keys (not just lengths) also catches
        # equal-sized but mismatched layer sets.
        if set(output_feature.keys()) != set(target_feature.keys()):
            raise ValueError(
                "output and target features have different layers: "
                f"{sorted(output_feature.keys())} vs {sorted(target_feature.keys())}"
            )
        keys = sorted(output_feature.keys())
        if self.weights is not None and len(self.weights) < len(keys):
            raise ValueError(
                f"got {len(self.weights)} weights for {len(keys)} feature layers"
            )

        total_loss = 0
        for i, k in enumerate(keys):
            loss = self.loss_fn(output_feature[k], target_feature[k])
            if self.weights is not None:
                # Out-of-place, so the tensor returned by loss_fn is never
                # mutated (in-place ops can break autograd for some losses).
                loss = loss * self.weights[i]
            total_loss = total_loss + loss
        return total_loss
|