NeuCoSVC-Colab / modules /interpolation.py
kevinwang676's picture
Upload folder using huggingface_hub
cfdc687
raw
history blame contribute delete
753 Bytes
import torch
from modules.base import BaseModule
class InterpolationBlock(BaseModule):
    """Resize the last dimension of a tensor by a fixed factor.

    Wraps ``torch.nn.functional.interpolate`` with an explicit target
    ``size`` computed from the input's last dimension ``L``: the output
    length is ``L * scale_factor`` when upsampling and
    ``L // scale_factor`` when ``downsample`` is True.

    Args:
        scale_factor: Factor applied to the last dimension. May be an int
            or a float; the resulting length is truncated to an int so it
            is always a valid ``size`` for ``interpolate``.
        mode: Interpolation algorithm forwarded to ``interpolate``.
        align_corners: Forwarded to ``interpolate`` only for the modes that
            accept it (linear / bilinear / bicubic / trilinear). For other
            modes (e.g. ``'nearest'``, ``'area'``) it is ignored, because
            torch raises a ValueError if it is set for them.
        downsample: If True, divide the length by ``scale_factor`` instead
            of multiplying by it.
    """

    # Modes for which torch accepts a non-None ``align_corners``; any other
    # mode raises a ValueError if ``align_corners`` is set.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, scale_factor, mode='linear', align_corners=False, downsample=False):
        super().__init__()
        self.downsample = downsample
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def _target_length(self, length):
        """Return the output length for an input whose last dim is ``length``."""
        if self.downsample:
            # Floor division; int() covers float scale factors (e.g. 7 // 1.5 == 4.0).
            return int(length // self.scale_factor)
        # int() keeps the size integral when scale_factor is a float.
        return int(length * self.scale_factor)

    def forward(self, x):
        """Interpolate ``x`` along its last dimension.

        Args:
            x: Input tensor; the target size is derived from ``x.shape[-1]``.
                NOTE(review): with ``mode='linear'`` torch requires a 3-D
                input, presumably ``(batch, channels, length)`` here —
                confirm against callers.

        Returns:
            The interpolated tensor, whose last dimension has length
            ``self._target_length(x.shape[-1])``.
        """
        align_corners = (
            self.align_corners if self.mode in self._ALIGN_CORNERS_MODES else None
        )
        return torch.nn.functional.interpolate(
            x,
            size=self._target_length(x.shape[-1]),
            mode=self.mode,
            align_corners=align_corners,
            # Has no effect with an explicit ``size``; torch only rejects True here.
            recompute_scale_factor=False,
        )