import inspect
import random

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import PatchEmbed

from .nn import timestep_embedding
from .unet import UNetModel
from .xf import LayerNorm, Transformer, convert_module_to_f16


class Text2ImModel(nn.Module):
    """Diffusion model conditioned on a reference image.

    An ``Encoder`` turns the reference image into a pooled vector and a token
    sequence; a ``Text2ImUNet`` decoder receives both together with the noisy
    input and the timesteps.

    NOTE(review): ``text_ctx`` and ``xf_final_ln`` are accepted but never read,
    and ``xf_layers`` is not forwarded to the encoder (its depth is fixed at 8).
    This looks like a leftover from a text-conditioned signature; the behavior
    is kept as-is so existing checkpoints and callers are unaffected.
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout,
        channel_mult,
        use_fp16,
        num_heads,
        num_heads_upsample,
        num_head_channels,
        use_scale_shift_norm,
        resblock_updown,
        in_channels=3,
        n_class=3,
        image_size=64,
    ):
        super().__init__()

        # The encoder depth is hard-coded to 8 layers regardless of `xf_layers`.
        self.encoder = Encoder(
            img_size=image_size,
            patch_size=image_size // 16,
            in_chans=n_class,
            xf_width=xf_width,
            xf_layers=8,
            xf_heads=xf_heads,
            model_channels=model_channels,
        )

        self.in_channels = in_channels

        self.decoder = Text2ImUNet(
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_heads_upsample=num_heads_upsample,
            num_head_channels=num_head_channels,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            encoder_channels=xf_width,
        )

    def forward(self, xt, timesteps, ref=None, uncond_p=0.0):
        """Encode ``ref`` and run the decoder on ``xt`` at ``timesteps``.

        ``uncond_p`` is handed to the encoder, which currently ignores it.
        Returns whatever the decoder returns.
        """
        latent_outputs = self.encoder(ref, uncond_p)
        return self.decoder(xt, timesteps, latent_outputs)


class Text2ImUNet(UNetModel):
    """``UNetModel`` whose blocks are additionally fed the encoder outputs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps the pooled encoder vector into the timestep-embedding space.
        # The input width is fixed at 512, so the encoder's pooled output must
        # have 512 features.
        self.transformer_proj = nn.Linear(512, self.model_channels * 4)

    def forward(self, x, timesteps, latent_outputs):
        """Run the UNet.

        :param x: input tensor; it is cast to ``self.dtype`` for the blocks and
            the result is cast back to ``x.dtype`` before the output layer.
        :param timesteps: passed to ``timestep_embedding``.
        :param latent_outputs: dict with keys ``"xf_proj"`` (pooled vector) and
            ``"xf_out"`` (token sequence), as produced by ``Encoder.forward``.
        :return: the output of ``self.out``.
        """
        skips = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        xf_proj = latent_outputs["xf_proj"]
        xf_out = latent_outputs["xf_out"]

        # The projected pooled vector is added to the timestep embedding.
        xf_proj = self.transformer_proj(xf_proj)
        emb = emb + xf_proj.to(emb)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            skips.append(h)

        h = self.middle_block(h, emb, xf_out)

        # Each output block consumes the most recent unused skip activation.
        for module in self.output_blocks:
            h = th.cat([h, skips.pop()], dim=1)
            h = module(h, emb, xf_out)

        h = h.type(x.dtype)
        return self.out(h)


class Encoder(nn.Module):
    """CNN feature extractor followed by a transformer over the feature map.

    ``img_size``, ``patch_size`` and ``model_channels`` are accepted for
    interface compatibility but are not used here.

    The positional table holds 256 spatial slots plus one class-token slot.
    ``CNN`` downsamples by 16 in each spatial dimension, so e.g. a 256x256
    reference yields the required 16 * 16 = 256 tokens.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        xf_width,
        xf_layers,
        xf_heads,
        model_channels,
    ):
        super().__init__()
        self.transformer = Transformer(
            xf_width,
            xf_layers,
            xf_heads,
        )

        self.cnn = CNN(in_chans)

        self.final_ln = LayerNorm(xf_width)

        # Zero-initialized rather than `th.empty`: uninitialized memory may
        # contain NaN/Inf and would poison the model whenever no checkpoint is
        # loaded. Zeros are deterministic and do not consume RNG state, and a
        # loaded state dict overwrites them exactly as before.
        self.cls_token = nn.Parameter(th.zeros(1, 1, xf_width, dtype=th.float32))
        self.positional_embedding = nn.Parameter(
            th.zeros(1, 256 + 1, xf_width, dtype=th.float32)
        )

    def forward(self, ref, uncond_p=0.0):
        """Encode ``ref`` into a pooled vector and a token sequence.

        ``uncond_p`` is accepted but not used.

        :return: dict with ``xf_proj`` (the class-token output) and ``xf_out``
            (the remaining token outputs, permuted from NLC to NCL).
        """
        features = self.cnn(ref)
        # (N, C, H, W) -> (N, H*W, C)
        tokens = features.flatten(2).transpose(1, 2)

        # Slot 0 of the table belongs to the class token, slots 1.. to the
        # spatial tokens.
        tokens = tokens + self.positional_embedding[:, 1:, :]
        cls_token = self.cls_token + self.positional_embedding[:, :1, :]
        cls_tokens = cls_token.expand(tokens.shape[0], -1, -1)

        # The class token is appended at the END of the sequence, which is why
        # the pooled output below is read from index -1.
        tokens = th.cat((tokens, cls_tokens), dim=1)

        xf_out = self.transformer(tokens)
        if self.final_ln is not None:
            xf_out = self.final_ln(xf_out)

        xf_proj = xf_out[:, -1]
        xf_out = xf_out[:, :-1].permute(0, 2, 1)  # NLC -> NCL

        return dict(xf_proj=xf_proj, xf_out=xf_out)


class SuperResText2ImModel(Text2ImModel):
    """
    A text2im model that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    The low-resolution image is upsampled and concatenated to the input along
    the channel dimension, so the parent is built with twice the requested
    `in_channels`.
    """

    def __init__(self, *args, **kwargs):
        # Locate `in_channels` in the parent's signature instead of assuming a
        # fixed position: index 1 of Text2ImModel's positional arguments is
        # `xf_width`, so doubling `args[1]` would widen the wrong parameter.
        parent_params = inspect.signature(Text2ImModel.__init__).parameters
        position = list(parent_params).index("in_channels") - 1  # skip `self`

        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 2
        elif len(args) > position:
            args = list(args)
            args[position] = args[position] * 2
        else:
            # `in_channels` was omitted entirely: double the parent's default.
            kwargs = dict(kwargs)
            kwargs["in_channels"] = parent_params["in_channels"].default * 2

        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """Upsample ``low_res`` to the size of ``x`` and run the parent model.

        :param x: 4-D input; its last two dimensions set the target size.
        :param timesteps: forwarded to the parent ``forward``.
        :param low_res: low-resolution conditioning image (required).
        :param kwargs: forwarded to the parent ``forward`` (e.g. ``ref``).
        :raises ValueError: if ``low_res`` is not provided.
        """
        if low_res is None:
            raise ValueError("SuperResText2ImModel.forward requires `low_res`")

        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )

        # NOTE: disabled experiment, kept for reference -- timestep-dependent
        # noise on the upsampled conditioning image:
        # upsampled = upsampled + th.randn_like(upsampled)*0.0005*th.log(1 + 0.1* timesteps.reshape(timesteps.shape[0], 1,1,1))

        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)


def conv3x3(in_channels, out_channels, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )


def conv7x7(in_channels, out_channels, stride=1):
    """Return a biased 7x7 ``nn.Conv2d`` with padding 3."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=7,
        stride=stride,
        padding=3,
        bias=True,
    )


class CNN(nn.Module):
    """Six-layer convolutional feature extractor with 512 output channels.

    Four stride-2 convolutions reduce each spatial dimension by a factor of 16.
    The size comments below are the feature-map sizes for a 256x256 input.
    Attribute names are part of the checkpoint format and must not be renamed.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        self.conv1 = conv7x7(in_channels, 32)  # 256
        self.norm1 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU1 = nn.LeakyReLU(0.2)

        self.conv2 = conv3x3(32, 64, 2)  # 128
        self.norm2 = nn.InstanceNorm2d(64, affine=True)
        self.LReLU2 = nn.LeakyReLU(0.2)

        self.conv3 = conv3x3(64, 128, 2)  # 64
        self.norm3 = nn.InstanceNorm2d(128, affine=True)
        self.LReLU3 = nn.LeakyReLU(0.2)

        self.conv4 = conv3x3(128, 256, 2)  # 32
        self.norm4 = nn.InstanceNorm2d(256, affine=True)
        self.LReLU4 = nn.LeakyReLU(0.2)

        self.conv5 = conv3x3(256, 512, 2)  # 16
        self.norm5 = nn.InstanceNorm2d(512, affine=True)
        self.LReLU5 = nn.LeakyReLU(0.2)

        # Final projection: no normalization or activation follows it.
        self.conv6 = conv3x3(512, 512, 1)

    def forward(self, x):
        """Apply the five conv/norm/LeakyReLU stages, then the final conv."""
        stages = (
            (self.conv1, self.norm1, self.LReLU1),
            (self.conv2, self.norm2, self.LReLU2),
            (self.conv3, self.norm3, self.LReLU3),
            (self.conv4, self.norm4, self.LReLU4),
            (self.conv5, self.norm5, self.LReLU5),
        )
        for conv, norm, act in stages:
            x = act(norm(conv(x)))
        return self.conv6(x)