#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSAG.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: [email protected]
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch.nn as nn | |
from .esa import ESA | |
from .OSA import OSA_Block | |
class OSAG(nn.Module):
    """Residual group: a stack of ``OSA_Block`` modules closed by a 1x1 conv.

    The stack output is added to the group's input (skip connection) and the
    sum is then passed through an ``ESA`` module.

    Args:
        channel_num: Passed as the first positional argument to every
            ``OSA_Block``; also the in/out channel count of the trailing
            1x1 convolution and the second argument to ``ESA``.
        bias: Passed as the second positional argument to every
            ``OSA_Block`` and used as the ``bias`` flag of the 1x1 conv.
        block_num: Number of ``OSA_Block`` instances placed in the stack.
        ffn_bias: Forwarded to each ``OSA_Block`` as ``ffn_bias``.
        window_size: Forwarded to each ``OSA_Block`` as ``window_size``.
        pe: Forwarded to each ``OSA_Block`` as ``with_pe``.
    """

    def __init__(
        self,
        channel_num=64,
        bias=True,
        block_num=4,
        ffn_bias=False,
        window_size=0,
        pe=False,
    ):
        super().__init__()

        # Build order is blocks -> 1x1 conv -> ESA; keeping it fixed keeps
        # the Sequential indices (and hence checkpoint keys) stable.
        layers = [
            OSA_Block(
                channel_num,
                bias,
                ffn_bias=ffn_bias,
                window_size=window_size,
                with_pe=pe,
            )
            for _ in range(block_num)
        ]
        layers.append(
            nn.Conv2d(
                channel_num,
                channel_num,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=bias,
            )
        )
        self.residual_layer = nn.Sequential(*layers)

        # First ESA argument is a quarter of the channels, floored at 16.
        # NOTE(review): presumably ESA's internal width -- confirm in esa.py.
        self.esa = ESA(max(channel_num // 4, 16), channel_num)

    def forward(self, x):
        """Run the block stack, add the input back, then apply ``self.esa``.

        Args:
            x: Input accepted by ``self.residual_layer``; the stack output
                must be addable to it.

        Returns:
            The result of ``self.esa`` applied to ``residual_layer(x) + x``.
        """
        return self.esa(self.residual_layer(x) + x)