Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from torchvision.ops import deform_conv2d | |
class DeformableConv2d(nn.Module):
    """Mask-modulated deformable 2-D convolution built on ``torchvision.ops.deform_conv2d``.

    Three convolutions share the same kernel size, stride and padding:

    * ``offset_conv`` predicts ``2 * kh * kw`` sampling offsets per output
      location (two per kernel tap);
    * ``modulator_conv`` predicts ``kh * kw`` values that become the per-tap
      mask after ``2 * sigmoid(.)``;
    * ``regular_conv`` is never run directly. Only its ``weight`` and ``bias``
      are passed to ``deform_conv2d``. It is kept as a submodule so the
      parameter names stay stable in ``state_dict``.

    Both predictor convolutions are zero-initialised. At construction the
    offsets are therefore 0 and the mask is ``2 * sigmoid(0) = 1``, so the
    layer initially behaves like a plain convolution with ``regular_conv``'s
    parameters.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: ``int`` or ``(kh, kw)`` pair.
        stride: ``int`` or ``(sh, sw)`` pair.
        padding: ``int`` or ``(ph, pw)`` pair.
        bias: Whether ``regular_conv`` carries a learnable bias.

    Raises:
        TypeError: If ``kernel_size``, ``stride`` or ``padding`` is not an
            ``int`` or a pair of ``int``s.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=False):
        super().__init__()
        kernel_size = self._as_pair(kernel_size, "kernel_size")
        self.stride = self._as_pair(stride, "stride")
        self.padding = self._as_pair(padding, "padding")
        taps = kernel_size[0] * kernel_size[1]

        # Offset predictor: one (dy, dx) pair per kernel tap.
        self.offset_conv = nn.Conv2d(in_channels,
                                     2 * taps,
                                     kernel_size=kernel_size,
                                     stride=self.stride,
                                     padding=self.padding,
                                     bias=True)
        nn.init.constant_(self.offset_conv.weight, 0.)
        nn.init.constant_(self.offset_conv.bias, 0.)

        # Mask predictor: one scalar per kernel tap.
        self.modulator_conv = nn.Conv2d(in_channels,
                                        1 * taps,
                                        kernel_size=kernel_size,
                                        stride=self.stride,
                                        padding=self.padding,
                                        bias=True)
        nn.init.constant_(self.modulator_conv.weight, 0.)
        nn.init.constant_(self.modulator_conv.bias, 0.)

        # Holds the convolution weights/bias consumed by deform_conv2d.
        self.regular_conv = nn.Conv2d(in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=self.stride,
                                      padding=self.padding,
                                      bias=bias)

    @staticmethod
    def _as_pair(value, name):
        """Return ``value`` as a 2-tuple, expanding a scalar ``int`` to ``(v, v)``.

        Raises:
            TypeError: If ``value`` is neither an ``int`` nor a length-2
                tuple/list of ``int``s.
        """
        # bool is a subclass of int but is never a meaningful size here.
        if isinstance(value, int) and not isinstance(value, bool):
            return (value, value)
        if (isinstance(value, (tuple, list)) and len(value) == 2
                and all(isinstance(v, int) for v in value)):
            return tuple(value)
        raise TypeError(f"{name} must be an int or a pair of ints, got {value!r}")

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x``.

        Args:
            x: Input tensor. ``deform_conv2d`` expects a 4-D
                ``(batch, in_channels, H, W)`` tensor.

        Returns:
            The ``deform_conv2d`` output with ``out_channels`` channels and the
            spatial size implied by ``kernel_size``, ``stride`` and ``padding``.
        """
        offset = self.offset_conv(x)
        # 2 * sigmoid keeps the mask in (0, 2). A zero-initialised predictor
        # starts at exactly 1.0, i.e. no modulation.
        mask = 2. * torch.sigmoid(self.modulator_conv(x))
        return deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            stride=self.stride,
            padding=self.padding,
            mask=mask,
        )