import torch
from modules.base import BaseModule


class InterpolationBlock(BaseModule):
    """Resample the spatial dimensions of a tensor by a fixed factor.

    Wraps ``torch.nn.functional.interpolate``. Instead of forwarding
    ``scale_factor`` directly, it computes an explicit integer output ``size``
    for every spatial dimension: ``int(L * scale_factor)`` when upsampling and
    ``int(L // scale_factor)`` when downsampling.

    Args:
        scale_factor: Factor by which each spatial dimension is multiplied
            (``downsample=False``) or floor-divided (``downsample=True``).
            May be an int or a float.
        mode: Interpolation algorithm forwarded to ``interpolate``
            (e.g. ``'linear'``, ``'bilinear'``, ``'nearest'``).
        align_corners: Forwarded to ``interpolate`` for interpolating modes.
            It is not forwarded for ``'nearest'``, ``'nearest-exact'`` and
            ``'area'``, because ``interpolate`` rejects any non-``None`` value
            for those modes.
        downsample: If True, floor-divide sizes by ``scale_factor`` instead of
            multiplying them.
    """

    # Modes for which torch.nn.functional.interpolate raises ValueError when
    # align_corners is anything other than None.
    _NO_ALIGN_CORNERS_MODES = ('nearest', 'nearest-exact', 'area')

    def __init__(self, scale_factor, mode='linear', align_corners=False, downsample=False):
        super().__init__()
        self.downsample = downsample
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def _output_size(self, length):
        """Return the resampled length of one spatial dimension as an int."""
        if self.downsample:
            new_length = length // self.scale_factor
        else:
            new_length = length * self.scale_factor
        # interpolate() requires integer sizes; with a float scale_factor the
        # arithmetic above yields a float that the upsample kernel rejects.
        return int(new_length)

    def forward(self, x):
        """Resample every spatial dimension of ``x``.

        Args:
            x: Input tensor in the layout ``interpolate`` expects,
                ``(batch, channels, *spatial)``; the number of spatial dims
                must suit ``mode`` (1 for ``'linear'``, 2 for ``'bilinear'``,
                3 for ``'trilinear'``).

        Returns:
            Tensor with the same batch and channel dims and each spatial dim
            resized by ``scale_factor``.
        """
        # One size per spatial dim, so multi-dimensional modes do not collapse
        # every spatial dim onto the last dim's length.
        size = tuple(self._output_size(dim) for dim in x.shape[2:])
        if self.mode in self._NO_ALIGN_CORNERS_MODES:
            align_corners = None
        else:
            align_corners = self.align_corners
        return torch.nn.functional.interpolate(
            x,
            size=size,
            mode=self.mode,
            align_corners=align_corners,
            recompute_scale_factor=False,
        )