from functools import partial

import torch
from torch import nn

from .Advanced_Network_Helpers import *


class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        # the noisy image and the explicit conditioning are concatenated
        # along the channel dimension, hence channels * 2 input channels
        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.in_out = in_out

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.conditioning_encoder = nn.ModuleList([])
        num_resolutions = len(in_out)
        self.num_resolutions = num_resolutions

        # conditioning encoder: mirrors the contracting path so the implicit
        # conditioning reaches the bottleneck at the same resolution as x
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.conditioning_encoder.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        # contracting path
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        # bottleneck: cross-attention layers inject the encoded conditioning
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.cross_attention_1 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.cross_attention_2 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.cross_attention_3 = Residual(
            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # expanding path
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        # dim_out * 2: the skip connection is concatenated below
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
        # explicit conditioning is concatenated with the input along channels
        x = torch.cat((x, explicit_conditioning), dim=1)
        x = self.init_conv(x)
        conditioning = self.conditioning_init(implicit_conditioning)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # encode the implicit conditioning down to the bottleneck resolution
        for block1, attn, downsample in self.conditioning_encoder:
            conditioning = block1(conditioning)
            conditioning = attn(conditioning)
            conditioning = downsample(conditioning)

        # contracting path, storing skip connections in h
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck: alternate cross-attention on the encoded conditioning
        # with time-conditioned blocks
        x = self.cross_attention_1(x, conditioning)
        x = self.mid_block1(x, t)
        x = self.cross_attention_2(x, conditioning)
        x = self.mid_block2(x, t)
        x = self.cross_attention_3(x, conditioning)

        # expanding path, consuming the skip connections in reverse order
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
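
# ----------------------------------------------------------------------
# Minimal smoke test. This is a sketch only: the building blocks
# (ConvNextBlock, LinearAttention, LinearCrossAttention, Downsample,
# Upsample, Residual, PreNorm, SinusoidalPositionEmbeddings, default,
# exists) are star-imported from Advanced_Network_Helpers, so their exact
# signatures are assumed here, and the shapes below are illustrative.
# Because of the relative import, run it as a module (python -m pkg.module),
# not as a standalone script.
if __name__ == "__main__":
    model = Unet(dim=64, channels=3)

    x = torch.randn(2, 3, 64, 64)         # noisy input images
    t = torch.randint(0, 1000, (2,))      # diffusion timesteps
    implicit = torch.randn(2, 3, 64, 64)  # conditioning for the cross-attention branch
    explicit = torch.randn(2, 3, 64, 64)  # conditioning concatenated at the input

    out = model(x, t, implicit, explicit)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])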