kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
raw
history blame
633 Bytes
import torch
import torch.nn as nn
class TVLoss(nn.Module):
    """Total-variation (smoothness) loss.

    Penalizes squared differences between vertically and horizontally
    adjacent elements over the last two dimensions of the input.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _mean_sq(diff: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``diff ** 2``, or 0 when ``diff`` is empty.

        Dividing by ``max(numel, 1)`` (instead of calling ``.mean()``) keeps an
        empty difference tensor, e.g. from an input of height 1, at 0 rather
        than NaN.
        """
        return diff.pow(2).sum() / max(diff.numel(), 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the total-variation loss of ``x``.

        Args:
            x (torch.Tensor): tensor with at least 2 dimensions. The last two
                are treated as height and width, e.g. ``(b, c, h, w)``. Any
                leading dimensions are averaged over.

        Returns:
            torch.Tensor: 0-dim tensor equal to
            ``2 * (mean squared vertical difference
            + mean squared horizontal difference)``. A direction with only one
            row/column contributes 0.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                f"TVLoss expects a tensor with at least 2 dims, got shape {tuple(x.shape)}"
            )
        # Differences between neighbours along height (dim -2) and width (dim -1).
        # For 4-D input each numel() equals b*c*(h-1)*w resp. b*c*h*(w-1).
        h_tv = self._mean_sq(x[..., 1:, :] - x[..., :-1, :])
        w_tv = self._mean_sq(x[..., :, 1:] - x[..., :, :-1])
        return 2 * (h_tv + w_tv)