Spaces:
No application file
No application file
File size: 633 Bytes
6755a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn as nn
class TVLoss(nn.Module):
    """Total-variation (smoothness) loss over the last two dimensions of a tensor.

    The returned value is ``2 * (mean(dH ** 2) + mean(dW ** 2))``, where ``dH``
    holds the differences between vertically adjacent elements (along dim -2)
    and ``dW`` those between horizontally adjacent elements (along dim -1).
    Each mean is taken over every element of the corresponding difference
    tensor, i.e. across all leading (batch/channel) dimensions as well.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _mean_sq(diff: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``diff ** 2``, or 0 if ``diff`` is empty."""
        # Sum first and divide once instead of dividing every element.
        # ``max(..., 1)`` keeps an empty diff (spatial size 1) at 0 rather than
        # producing 0 / 0 = nan.
        return diff.pow(2).sum() / max(diff.numel(), 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the total-variation loss of ``x``.

        Args:
            x (torch.Tensor): tensor of shape ``(*, H, W)``; the common case is
                ``(B, C, H, W)``. Differences are taken along the last two
                dimensions only.

        Returns:
            torch.Tensor: 0-dim tensor,
            ``2 * (mean squared vertical diff + mean squared horizontal diff)``.
            A spatial dimension of size 1 contributes 0 to the loss.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                f"TVLoss expects a tensor with at least 2 dims (*, H, W), got shape {tuple(x.shape)}"
            )
        h_tv = self._mean_sq(x[..., 1:, :] - x[..., :-1, :])
        w_tv = self._mean_sq(x[..., :, 1:] - x[..., :, :-1])
        # Fixed scale factor of 2, kept for compatibility with the original loss.
        return 2 * (h_tv + w_tv)
|