Spaces:
Running
on
T4
Running
on
T4
import torch | |
from torch import nn | |
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d layer that feeds it.

    Returns a new ``nn.Conv2d`` that computes ``bn(conv(x))`` with a single
    convolution, using ``bn``'s running statistics (i.e. eval-mode batch norm).
    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/.

    With a per-output-channel ``scale = gamma / sqrt(running_var + eps)``:

        W_fused = scale * W_conv
        b_fused = scale * (b_conv - running_mean) + beta

    Args:
        conv: the ``nn.Conv2d`` whose output is fed into ``bn``.
        bn: the ``nn.BatchNorm2d`` applied to ``conv``'s output. It must track
            running statistics. If it has no affine parameters
            (``affine=False``), gamma=1 and beta=0 are used.

    Returns:
        A new ``nn.Conv2d`` with ``bias=True``, placed on ``conv``'s device and
        dtype, whose parameters have gradients disabled.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # dilation / padding_mode must match, otherwise the fused layer has a
            # different receptive field (and output size) than the original conv
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )

    # no_grad: conv/bn parameters usually require grad; copying them into the
    # frozen fused parameters must not attach autograd history to those.
    with torch.no_grad():
        std = torch.sqrt(bn.running_var + bn.eps)
        gamma = bn.weight if bn.weight is not None else torch.ones_like(std)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(std)
        scale = gamma / std  # one factor per output channel

        # prepare filters: scale each output channel's filter by broadcasting
        # (equivalent to diag(scale) @ W without building a C_out x C_out matrix)
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        # prepare spatial bias
        b_conv = torch.zeros_like(std) if conv.bias is None else conv.bias
        fusedconv.bias.copy_(scale * (b_conv - bn.running_mean) + beta)

    return fusedconv
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes of ``b`` (its ``__dict__`` entries) onto ``a``.

    Names starting with an underscore are never copied. If ``include`` is
    non-empty, only names in ``include`` are considered. Names in ``exclude``
    are always skipped.
    """

    def _wanted(name):
        # Same checks, in the same order, as a single short-circuited condition.
        if include and name not in include:
            return False
        if name.startswith("_"):
            return False
        return name not in exclude

    for name, value in b.__dict__.items():
        if _wanted(name):
            setattr(a, name, value)