Spaces:
No application file
No application file
File size: 626 Bytes
6755a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from typing import List, Union
import torch
import torch.nn as nn
from .multi_layer_loss import MultiLayerLoss
class StyleLoss(MultiLayerLoss):
    """Multi-layer style loss built on Gram matrices of feature maps.

    Per-layer losses come from ``cal_single_layer_loss``; how they are
    combined across layers (and how ``weights`` is applied) is defined by
    ``MultiLayerLoss``, which is not visible here.
    """

    def __init__(
        self,
        loss_fn: nn.Module = nn.MSELoss,
        weights: Union[List[float], None] = None,
    ) -> None:
        """Forward ``loss_fn`` and ``weights`` to ``MultiLayerLoss``.

        Args:
            loss_fn: Criterion used to compare two layer representations.
                NOTE(review): the default is the ``nn.MSELoss`` *class*, not an
                instance; that is only correct if ``MultiLayerLoss``
                instantiates it before use — confirm against the base class.
            weights: Optional per-layer weights, passed through unchanged.
        """
        super().__init__(loss_fn, weights)

    def cal_single_layer_loss(self, x, y):
        """Return the size-normalised loss between two layer representations.

        Computes ``self.loss_fn(x, y) / (c * h * w) ** 2 / 4``. With
        ``N = c`` and ``M = h * w`` this has the form of the
        ``1 / (4 N^2 M^2)`` factor used by Gatys et al.

        Args:
            x: Tensor whose shape unpacks as ``(b, c, h, w)``; only ``c``,
                ``h`` and ``w`` are used for the normalisation.
            y: Tensor compared against ``x`` by ``self.loss_fn``.

        Returns:
            The scaled output of ``self.loss_fn(x, y)``.
        """
        # NOTE(review): the 4-way unpack requires a 4-D ``x``, whereas
        # ``_get_feature`` produces 3-D Gram matrices. Confirm in
        # MultiLayerLoss which representation is passed here.
        _, c, h, w = x.shape
        return self.loss_fn(x, y) / (c * h * w) ** 2 / 4

    def _get_feature(self, x: torch.Tensor) -> torch.Tensor:
        """Return the batched, unnormalised Gram matrix of a feature map.

        Args:
            x: Tensor of shape ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, c)`` whose entry ``[n, i, j]`` is the
            inner product of channels ``i`` and ``j`` of sample ``n`` taken
            over all ``h * w`` spatial positions.
        """
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs such as channels_last
        # feature maps are accepted; for contiguous input it is a free view.
        flat = x.reshape(b, c, h * w)
        # The Gram matrix is a matrix product (b, c, hw) @ (b, hw, c);
        # Tensor.mul is elementwise and takes only one operand, so use bmm.
        return torch.bmm(flat, flat.transpose(1, 2))
|