""" backbone.py - Contains the backbone of the model. (It is based on LPIENet and CURL's backbone) Perceptual Image Enhancement for Smartphone Real-Time Applications https://github.com/mv-lab/AISP CURL: Neural Curve Layers for Global Image Enhancement https://github.com/sjmoran/CURL David Serrano (dserrano@cvc.uab.cat) May 2024 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List class AttentionBlock(nn.Module): def __init__(self, dim: int): super(AttentionBlock, self).__init__() self._spatial_attention_conv = nn.Conv2d(2, dim, kernel_size=3, padding=1) # Channel attention MLP self._channel_attention_conv0 = nn.Conv2d(1, dim, kernel_size=1, padding=0) self._channel_attention_conv1 = nn.Conv2d(dim, dim, kernel_size=1, padding=0) self._out_conv = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor): if len(x.shape) != 4: raise ValueError(f"Expected [B, C, H, W] input, got {x.shape}.") # Spatial attention mean = torch.mean(x, dim=1, keepdim=True) # Mean/Max on C axis max, _ = torch.max(x, dim=1, keepdim=True) spatial_attention = torch.cat([mean, max], dim=1) # [B, 2, H, W] spatial_attention = self._spatial_attention_conv(spatial_attention) spatial_attention = torch.sigmoid(spatial_attention) * x # NOTE: This differs from CBAM as it uses Channel pooling, not spatial pooling! # In a way, this is 2x spatial attention channel_attention = torch.relu(self._channel_attention_conv0(mean)) channel_attention = self._channel_attention_conv1(channel_attention) channel_attention = torch.sigmoid(channel_attention) * x attention = torch.cat([spatial_attention, channel_attention], dim=1) # [B, 2*dim, H, W] attention = self._out_conv(attention) return x + attention class InverseBlock(nn.Module): def __init__(self, input_channels: int, channels: int): super(InverseBlock, self).__init__() self._conv0 = nn.Conv2d(input_channels, channels, kernel_size=1) self._dw_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels) self._conv1 = nn.Conv2d(channels, channels, kernel_size=1) self._conv2 = nn.Conv2d(input_channels, channels, kernel_size=1) def forward(self, x: torch.Tensor): features = self._conv0(x) features = F.elu(self._dw_conv(features)) features = self._conv1(features) x = torch.relu(self._conv2(x)) return x + features class BaseBlock(nn.Module): def __init__(self, channels: int): super(BaseBlock, self).__init__() self._conv0 = nn.Conv2d(channels, channels, kernel_size=1) self._dw_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels) self._conv1 = nn.Conv2d(channels, channels, kernel_size=1) self._conv2 = nn.Conv2d(channels, channels, kernel_size=1) self._conv3 = nn.Conv2d(channels, channels, kernel_size=1) def forward(self, x: torch.Tensor): features = self._conv0(x) features = F.elu(self._dw_conv(features)) features = self._conv1(features) x = x + features features = F.elu(self._conv2(x)) features = self._conv3(features) return x + features class AttentionTail(nn.Module): def __init__(self, channels: int): super(AttentionTail, self).__init__() self._conv0 = nn.Conv2d(channels, channels, kernel_size=7, padding=3) self._conv1 = nn.Conv2d(channels, channels, kernel_size=5, padding=2) self._conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) def forward(self, x: torch.Tensor): attention = torch.relu(self._conv0(x)) attention = torch.relu(self._conv1(attention)) attention = torch.sigmoid(self._conv2(attention)) return x * attention class Flatten(nn.Module): def forward(self, x): """Flatten a Tensor to a Vector :param x: Tensor :returns: 1D Tensor :rtype: Tensor """ return x.view(x.size()[0], -1) class ResidualConnection(nn.Module): def __init__(self, in_channels): super(ResidualConnection, self).__init__() self.in_channels = in_channels self.midnet2 = nn.Sequential( nn.Conv2d(in_channels, 64, 3, 1, 2, 2), nn.LeakyReLU(), nn.Conv2d(64, 64, 3, 1, 2, 2), nn.LeakyReLU() ) self.midnet4 = nn.Sequential( nn.Conv2d(in_channels, 64, 3, 1, 4, 4), nn.LeakyReLU(), nn.Conv2d(64, 64, 3, 1, 4, 4), nn.LeakyReLU() ) self.globnet = nn.Sequential( nn.Conv2d(in_channels, 64, 3, 2, 1, 1), nn.LeakyReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.Conv2d(64, 64, 3, 2, 1, 1), nn.LeakyReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.Conv2d(64, 64, 3, 2, 1, 1), nn.LeakyReLU(), nn.AdaptiveAvgPool2d(1), Flatten(), nn.Dropout(0.5), nn.Linear(64, 64) ) self.conv_fuse = nn.Conv2d(in_channels=192+in_channels, out_channels=in_channels, kernel_size=1) def forward(self, x): x_midnet2 = self.midnet2(x) x_midnet4 = self.midnet4(x) x_global = self.globnet(x).unsqueeze(2).unsqueeze(3) x_global = x_global.repeat(1, 1, x_midnet2.shape[2], x_midnet2.shape[3]) x_fuse = torch.cat((x, x_midnet2, x_midnet4, x_global), dim=1) x_out = self.conv_fuse(x_fuse) return x_out class Backbone(nn.Module): def __init__(self, input_channels: int, output_channels: int, encoder_dims: List[int], decoder_dims: List[int]): super(Backbone, self).__init__() if len(encoder_dims) != len(decoder_dims) + 1 or len(decoder_dims) < 1: raise ValueError(f"Unexpected encoder and decoder dims: {encoder_dims}, {decoder_dims}.") if input_channels != output_channels: raise NotImplementedError() encoders = [] for i, encoder_dim in enumerate(encoder_dims): input_dim = input_channels if i == 0 else encoder_dims[i - 1] encoders.append( nn.Sequential( nn.Conv2d(input_dim, encoder_dim, kernel_size=3, padding=1), BaseBlock(encoder_dim), BaseBlock(encoder_dim), AttentionBlock(encoder_dim), ) ) self._encoders = nn.ModuleList(encoders) decoders = [] for i, decoder_dim in enumerate(decoder_dims): input_dim = encoder_dims[-1] if i == 0 else decoder_dims[i - 1] + encoder_dims[-i - 1] decoders.append( nn.Sequential( nn.Conv2d(input_dim, decoder_dim, kernel_size=3, padding=1), BaseBlock(decoder_dim), BaseBlock(decoder_dim), AttentionBlock(decoder_dim), ) ) self._decoders = nn.ModuleList(decoders) self._inverse_bock = InverseBlock(encoder_dims[0] + decoder_dims[-1], output_channels) self._attention_tail = AttentionTail(output_channels) residual_connections = [] for i, decoder_dim in enumerate(encoder_dims): residual_connections.append( ResidualConnection(in_channels=decoder_dim) ) self._residual_connections = nn.ModuleList(residual_connections) def forward(self, x: torch.Tensor): if len(x.shape) != 4: raise ValueError(f"Expected [B, C, H, W] input, got {x.shape}.") global_residual = x encoder_outputs, residual_connections = [], [] for i, encoder in enumerate(self._encoders): x = encoder(x) if i != len(self._encoders) - 1: encoder_outputs.append(x) residual_connections.append(self._residual_connections[i](x)) x = F.max_pool2d(x, kernel_size=2) encoder_outputs.reverse() residual_connections.reverse() for i, decoder in enumerate(self._decoders): x = decoder(x) x = nn.Upsample(size=encoder_outputs[i].shape[2:], mode='bilinear', align_corners=False)(x) x = torch.cat([x, residual_connections[i]], dim=1) x = self._inverse_bock(x) x = self._attention_tail(x) return torch.clip(x + global_residual, 0, 1)