Spaces:
Runtime error
Runtime error
import torch as th | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import random | |
from .nn import timestep_embedding | |
from .unet import UNetModel | |
from .xf import LayerNorm, Transformer, convert_module_to_f16 | |
from timm.models.vision_transformer import PatchEmbed | |
class Text2ImModel(nn.Module):
    """Image-conditioned denoiser: an ``Encoder`` feeding a ``Text2ImUNet``.

    ``forward`` encodes the reference input ``ref`` and hands the resulting
    feature dict to the UNet decoder together with the noisy input ``xt`` and
    ``timesteps``.

    Constructor notes (all visible in this class):

    * ``text_ctx``, ``xf_layers`` and ``xf_final_ln`` are accepted for
      signature compatibility but are never read here.
    * The encoder is always built with 8 transformer layers, regardless of
      ``xf_layers``.
      NOTE(review): presumably fixed to match released checkpoints -- confirm
      before wiring ``xf_layers`` through.
    * ``n_class`` is passed to the encoder as ``in_chans`` (the channel count
      of ``ref``); ``in_channels`` is the decoder's first positional argument.
    * ``xf_width`` is forwarded to the decoder as ``encoder_channels``.
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout,
        channel_mult,
        use_fp16,
        num_heads,
        num_heads_upsample,
        num_head_channels,
        use_scale_shift_norm,
        resblock_updown,
        in_channels=3,
        n_class=3,
        image_size=64,
    ):
        super().__init__()

        # Registered first so the sub-module order (encoder, then decoder)
        # stays the same as before.
        self.encoder = Encoder(
            img_size=image_size,
            patch_size=image_size // 16,
            in_chans=n_class,
            xf_width=xf_width,
            xf_layers=8,
            xf_heads=xf_heads,
            model_channels=model_channels,
        )

        self.in_channels = in_channels

        # Keyword options for the UNet; the five leading arguments stay
        # positional because the UNetModel signature is not visible here.
        decoder_options = dict(
            dropout=dropout,
            channel_mult=channel_mult,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_heads_upsample=num_heads_upsample,
            num_head_channels=num_head_channels,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            encoder_channels=xf_width,
        )
        self.decoder = Text2ImUNet(
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            **decoder_options,
        )

    def forward(self, xt, timesteps, ref=None, uncond_p=0.0):
        """Return the decoder output for ``xt`` conditioned on ``ref``.

        :param xt: noisy input handed to the decoder.
        :param timesteps: timesteps handed to the decoder.
        :param ref: reference input for the encoder. The default is ``None``,
            but the value is passed straight to the encoder, so callers are
            expected to supply it.
        :param uncond_p: forwarded unchanged to the encoder.
        :return: whatever ``self.decoder`` returns.
        """
        conditioning = self.encoder(ref, uncond_p)
        return self.decoder(xt, timesteps, conditioning)
class Text2ImUNet(UNetModel):
    """``UNetModel`` variant conditioned on encoder features.

    All constructor arguments are forwarded unchanged to ``UNetModel``. On top
    of the base model this class adds ``transformer_proj``, a linear layer
    that maps the pooled encoder vector so it can be added to the timestep
    embedding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Input width is fixed at 512; output width is model_channels * 4.
        # NOTE(review): assumes the pooled encoder vector is 512-wide and that
        # time_embed (defined in UNetModel) emits model_channels * 4 features
        # -- confirm against UNetModel.
        self.transformer_proj = nn.Linear(512, self.model_channels * 4)

    def forward(self, x, timesteps, latent_outputs):
        """Run the UNet on ``x`` with timestep and encoder conditioning.

        :param x: input tensor; cast to ``self.dtype`` for the UNet body and
            back to its own dtype before the output layer.
        :param timesteps: timesteps fed through ``timestep_embedding`` and
            ``self.time_embed``.
        :param latent_outputs: mapping with keys ``"xf_proj"`` (pooled vector,
            projected and added to the timestep embedding) and ``"xf_out"``
            (passed to every block alongside the embedding).
        :return: the result of ``self.out`` applied to the final features.
        """
        pooled = latent_outputs["xf_proj"]
        tokens = latent_outputs["xf_out"]

        # Timestep embedding plus the projected pooled vector, the latter cast
        # to the embedding's dtype/device via ``.to``.
        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        cond = time_emb + self.transformer_proj(pooled).to(time_emb)

        feats = x.type(self.dtype)

        # Down path: remember every activation for the skip connections.
        skips = []
        for block in self.input_blocks:
            feats = block(feats, cond, tokens)
            skips.append(feats)

        feats = self.middle_block(feats, cond, tokens)

        # Up path: consume the skips in reverse order, concatenating on dim 1.
        for block in self.output_blocks:
            feats = block(th.cat([feats, skips.pop()], dim=1), cond, tokens)

        return self.out(feats.type(x.dtype))
class Encoder(nn.Module):
    """Encode a reference input into a pooled vector and per-token features.

    Pipeline: ``CNN`` feature map -> flatten to a token sequence -> add
    learned positional embeddings -> append a learned class token ->
    ``Transformer`` -> ``LayerNorm``.

    ``img_size``, ``patch_size`` and ``model_channels`` are accepted for
    signature compatibility but are not used. The positional embedding has a
    fixed length of 256 tokens plus one class slot, so the CNN feature map
    must flatten to 256 positions.
    NOTE(review): with the CNN in this file (overall stride 16) that
    corresponds to a 256x256 reference -- confirm against callers.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        xf_width,
        xf_layers,
        xf_heads,
        model_channels,
    ):
        super().__init__()
        self.transformer = Transformer(
            xf_width,
            xf_layers,
            xf_heads,
        )
        self.cnn = CNN(in_chans)
        self.final_ln = LayerNorm(xf_width)
        self.cls_token = nn.Parameter(th.empty(1, 1, xf_width, dtype=th.float32))
        self.positional_embedding = nn.Parameter(
            th.empty(1, 256 + 1, xf_width, dtype=th.float32)
        )
        # th.empty returns uninitialised memory (possibly NaN/inf). Give both
        # parameters a defined starting point so a model that is not restored
        # from a checkpoint can still be trained; load_state_dict overrides
        # these values as before.
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.02)

    def forward(self, ref, uncond_p=0.0):
        """Return ``dict(xf_proj=..., xf_out=...)`` for the reference ``ref``.

        :param ref: reference input passed to ``self.cnn``.
        :param uncond_p: accepted for interface compatibility; not used.
        :return: a dict with

            * ``"xf_proj"`` -- the last position of the normalised transformer
              output, i.e. the slot of the appended class token;
            * ``"xf_out"`` -- the remaining positions, permuted from
              (batch, length, channels) to (batch, channels, length).
        """
        x = self.cnn(ref)
        # (N, C, H, W) -> (N, H*W, C)
        x = x.flatten(2).transpose(1, 2)
        # Positional slot 0 belongs to the class token, slots 1.. to the
        # feature-map tokens.
        x = x + self.positional_embedding[:, 1:, :]
        cls_token = self.cls_token + self.positional_embedding[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        # The class token is appended at the END of the sequence, which is why
        # it is read back below with index -1.
        x = th.cat((x, cls_tokens), dim=1)
        xf_out = self.transformer(x)
        if self.final_ln is not None:
            xf_out = self.final_ln(xf_out)
        xf_proj = xf_out[:, -1]
        xf_out = xf_out[:, :-1].permute(0, 2, 1)  # NLC -> NCL
        outputs = dict(xf_proj=xf_proj, xf_out=xf_out)
        return outputs
class SuperResText2ImModel(Text2ImModel):
    """
    A text2im model that performs super-resolution.
    Expects an extra kwarg `low_res` to condition on a low-resolution image.

    The low-resolution image is resized to the spatial size of the input and
    concatenated to it along the channel dimension, so the underlying model is
    built with twice the requested ``in_channels``.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent model with ``in_channels`` doubled.

        ``in_channels`` is resolved by name against the parent signature, so
        the doubling is applied whether it was given as a keyword, given
        positionally, or left at its default. (Doubling ``args[1]``, as the
        previous implementation did for the non-keyword case, hit
        ``xf_width`` instead.)
        """
        import inspect  # local import: only needed here

        bound = inspect.signature(Text2ImModel.__init__).bind(self, *args, **kwargs)
        bound.apply_defaults()
        params = dict(bound.arguments)
        params.pop("self")
        params["in_channels"] = params["in_channels"] * 2
        # Every parent parameter is positional-or-keyword, so passing them all
        # by name is equivalent to the caller's original mix.
        super().__init__(**params)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """Concatenate the resized ``low_res`` to ``x`` and run the parent.

        :param x: input tensor; its last two dimensions give the target size.
        :param timesteps: forwarded to the parent ``forward``.
        :param low_res: low-resolution conditioning image. The default is
            ``None``, but the value is passed directly to ``F.interpolate``,
            so callers are expected to supply it.
        :param kwargs: forwarded to the parent ``forward`` (e.g. ``ref``).
        :return: the parent model's output.
        """
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        # Disabled experiment (kept for reference): timestep-dependent noise
        # on the upsampled conditioning image.
        # upsampled = upsampled + th.randn_like(upsampled)*0.0005*th.log(1 + 0.1* timesteps.reshape(timesteps.shape[0], 1,1,1))
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
def conv3x3(in_channels, out_channels, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    kernel = 3
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=True,
    )
def conv7x7(in_channels, out_channels, stride=1):
    """Return a biased 7x7 ``nn.Conv2d`` with padding 3 and the given stride."""
    kernel = 7
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=True,
    )
class CNN(nn.Module):
    """Convolutional feature extractor with an overall stride of 16.

    Five stages of convolution -> ``InstanceNorm2d(affine=True)`` ->
    ``LeakyReLU(0.2)`` with widths 32, 64, 128, 256, 512, followed by a final
    3x3 convolution (512 -> 512) that has no norm or activation. Stage 1 is a
    stride-1 7x7 convolution; stages 2-5 are stride-2 3x3 convolutions, each
    halving the spatial size.

    Sub-modules are registered as ``conv{i}``, ``norm{i}``, ``LReLU{i}`` for
    i = 1..5, plus ``conv6``.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        stage_widths = (32, 64, 128, 256, 512)
        prev_width = in_channels
        for stage, width in enumerate(stage_widths, start=1):
            # Only the stem uses the wide kernel and keeps the resolution.
            if stage == 1:
                conv = conv7x7(prev_width, width)
            else:
                conv = conv3x3(prev_width, width, 2)
            setattr(self, f"conv{stage}", conv)
            setattr(self, f"norm{stage}", nn.InstanceNorm2d(width, affine=True))
            setattr(self, f"LReLU{stage}", nn.LeakyReLU(0.2))
            prev_width = width
        self.conv6 = conv3x3(prev_width, 512, 1)

    def forward(self, x):
        """Apply the five normalised stages and the final convolution to ``x``."""
        # Stage 1 (full resolution).
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.LReLU1(x)
        # Stage 2 (1/2).
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.LReLU2(x)
        # Stage 3 (1/4).
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.LReLU3(x)
        # Stage 4 (1/8).
        x = self.conv4(x)
        x = self.norm4(x)
        x = self.LReLU4(x)
        # Stage 5 (1/16).
        x = self.conv5(x)
        x = self.norm5(x)
        x = self.LReLU5(x)
        # Output projection: no norm, no activation.
        return self.conv6(x)