File size: 904 Bytes
b416439 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: pixelshuffle.py
# Created Date: Friday July 1st 2022
# Author: Chen Xuanhong
# Email: [email protected]
# Last Modified: Friday, 1st July 2022 10:18:39 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch.nn as nn
def pixelshuffle_block(
    in_channels: int,
    out_channels: int,
    upscale_factor: int = 2,
    kernel_size: int = 3,
    bias: bool = False,
) -> nn.Sequential:
    """
    Upsample features according to `upscale_factor`.

    Builds a two-stage module: a convolution that expands the channel
    dimension to ``out_channels * upscale_factor ** 2``, followed by
    ``nn.PixelShuffle`` which folds those extra channels into space.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels after the pixel shuffle.
        upscale_factor: Factor passed to ``nn.PixelShuffle``.
        kernel_size: Kernel size of the convolution (a single int).
        bias: Whether the convolution has a learnable bias.

    Returns:
        ``nn.Sequential(conv, pixel_shuffle)``.
    """
    # "Same" padding for odd kernel sizes, so the convolution keeps the
    # spatial size and only the pixel shuffle changes the resolution.
    # Previously this value was computed but a fixed padding of 1 was passed
    # to the convolution, which is only correct for kernel_size == 3.
    padding = kernel_size // 2
    conv = nn.Conv2d(
        in_channels,
        out_channels * (upscale_factor**2),
        kernel_size,
        padding=padding,
        bias=bias,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return nn.Sequential(conv, pixel_shuffle)
|